Spaces: Running on Zero

XXXXRT666 committed
Commit d4d21ad · 1 Parent(s): d0754c2

Init

This view is limited to 50 files because it contains too many changes.
- .gitattributes +37 -0
- .gitignore +196 -0
- .pre-commit-config.yaml +15 -0
- GPT_SoVITS/Accelerate/MLX/__init__.py +12 -0
- GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py +181 -0
- GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py +99 -0
- GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py +103 -0
- GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py +65 -0
- GPT_SoVITS/Accelerate/MLX/structs_mlx.py +152 -0
- GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py +238 -0
- GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py +530 -0
- GPT_SoVITS/Accelerate/PyTorch/__init__.py +30 -0
- GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py +158 -0
- GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py +166 -0
- GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py +175 -0
- GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py +166 -0
- GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py +145 -0
- GPT_SoVITS/Accelerate/PyTorch/export.py +467 -0
- GPT_SoVITS/Accelerate/PyTorch/nn.py +69 -0
- GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py +67 -0
- GPT_SoVITS/Accelerate/PyTorch/structs.py +151 -0
- GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py +223 -0
- GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py +672 -0
- GPT_SoVITS/Accelerate/__init__.py +30 -0
- GPT_SoVITS/Accelerate/logger.py +203 -0
- GPT_SoVITS/configs/.gitignore +1 -0
- GPT_SoVITS/configs/s2.json +91 -0
- GPT_SoVITS/configs/s2v2Pro.json +91 -0
- GPT_SoVITS/configs/s2v2ProPlus.json +91 -0
- GPT_SoVITS/eres2net/ERes2NetV2.py +252 -0
- GPT_SoVITS/eres2net/fusion.py +27 -0
- GPT_SoVITS/eres2net/kaldi.py +844 -0
- GPT_SoVITS/eres2net/pooling_layers.py +101 -0
- GPT_SoVITS/f5_tts/model/__init__.py +3 -0
- GPT_SoVITS/f5_tts/model/backbones/README.md +20 -0
- GPT_SoVITS/f5_tts/model/backbones/dit.py +193 -0
- GPT_SoVITS/f5_tts/model/backbones/mmdit.py +144 -0
- GPT_SoVITS/f5_tts/model/backbones/unett.py +218 -0
- GPT_SoVITS/f5_tts/model/modules.py +665 -0
- GPT_SoVITS/feature_extractor/__init__.py +3 -0
- GPT_SoVITS/feature_extractor/cnhubert.py +46 -0
- GPT_SoVITS/inference_webui.py +1104 -0
- GPT_SoVITS/module/attentions.py +658 -0
- GPT_SoVITS/module/attentions_onnx.py +385 -0
- GPT_SoVITS/module/commons.py +185 -0
- GPT_SoVITS/module/core_vq.py +365 -0
- GPT_SoVITS/module/data_utils.py +1073 -0
- GPT_SoVITS/module/losses.py +70 -0
- GPT_SoVITS/module/mel_processing.py +142 -0
- GPT_SoVITS/module/models.py +1411 -0
.gitattributes
CHANGED
@@ -1 +1,38 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+GPT_SoVITS/text/G2PWModel/* filter=lfs diff=lfs merge=lfs -text
+GPT_SoVITS/text/G2PWModel/** filter=lfs diff=lfs merge=lfs -text
 GPT_SoVITS/text/ja_userdic/userdict.csv filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,196 @@
+.DS_Store
+.vscode
+__pycache__
+*.pyc
+env
+runtime
+.idea
+output
+logs
+SoVITS_weights*/
+GPT_weights*/
+TEMP
+weight.json
+ffmpeg*
+ffprobe*
+cfg.json
+speakers.json
+ref_audios
+tools/AP_BWE/24kto48k/*
+!tools/AP_BWE/24kto48k/readme.txt
+onnx_export
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,15 @@
+ci:
+  autoupdate_schedule: monthly
+
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.11.7
+    hooks:
+      # Run the linter.
+      - id: ruff
+        types_or: [ python, pyi ]
+        args: [ --fix , "--exit-zero" ]
+      # Run the formatter.
+      - id: ruff-format
+        types_or: [ python, pyi ]
+        args: [ --line-length, "120", --target-version, "py310" ]
GPT_SoVITS/Accelerate/MLX/__init__.py
ADDED
@@ -0,0 +1,12 @@
+import importlib.util
+import platform
+
+if importlib.util.find_spec("mlx") is not None and platform.system() == "Darwin":
+    from .sample_funcs_mlx import sample_naive as sample_naive_mlx
+    from .t2s_engine_mlx import T2SEngine as T2SEngineMLX
+
+    backends = ["mlx_static", "mlx_quantized_mxfp4", "mlx_quantized_affine", "mlx_varlen"]
+else:
+    backends = []
+
+__all__ = ["T2SEngineMLX", "sample_naive_mlx", "backends"]
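
A quick way to exercise the platform guard above from a REPL (a sketch; assumes the repository root is on sys.path):

import platform

from GPT_SoVITS.Accelerate.MLX import backends

# Empty on non-macOS hosts or when mlx is not installed;
# the four MLX backend names on Apple silicon.
print(platform.system(), backends)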
GPT_SoVITS/Accelerate/MLX/backends/mlx_quantized.py
ADDED
@@ -0,0 +1,181 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..structs_mlx import KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCache,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    @staticmethod
+    def quantized_scaled_dot_product_attention(
+        queries: Array,
+        q_keys: tuple[Array, Array, Array],
+        q_values: tuple[Array, Array, Array],
+        scale: float,
+        mask: Array,
+        group_size: int = 32,
+        bits: int = 8,
+    ) -> Array:
+        queries *= scale
+
+        scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
+        scores = mx.where(mask, scores, -mx.inf)
+        scores = mx.softmax(scores, axis=-1, precise=True)  # type: ignore
+        out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)
+
+        return out
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        max_idx = int(input_pos.max())
+
+        q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
+
+        mask = attn_mask[..., :max_idx]
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+    # def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+    #     bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+    #     q, k, v = self.in_proj(x).split(3, axis=-1)
+
+    #     q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+    #     q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+    #     kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+
+    #     assert len(kv_cache) == 3
+    #     (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits) = kv_cache
+
+    #     k_q, k_s, k_b, v_q, v_s, v_b = map(lambda x: x[..., : int(input_pos.max()), :], (k_q, k_s, k_b, v_q, v_s, v_b))
+
+    #     mask = attn_mask[..., : int(input_pos.max())]
+
+    #     attn = Attention.quantized_scaled_dot_product_attention(
+    #         q,
+    #         (k_q, k_s, k_b),
+    #         (v_q, v_s, v_b),
+    #         self.scale,
+    #         mask,
+    #         group_size,
+    #         bits,
+    #     )
+
+    #     attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+    #     output = self.out_proj(attn)
+
+    #     return output
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length, *args, **kwds)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length, *args, **kwds)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+        *args,
+        **kwds,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+            *args,
+            **kwds,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+                *args,
+                **kwds,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
+        self.group_size = 32
+        self.bits = 8
+        self.mode = "affine"
+
+    def set_mode(self, mode: str):
+        assert mode in ["affine", "mxfp4"]
+        self.mode = mode
+        if self.mode == "mxfp4":
+            self.bits = 4
+        else:
+            self.bits = 8
+
+    def quantized(self):
+        nn.quantize(self, self.group_size, self.bits, mode=self.mode)
+        # for layer in self.h.layers:
+        #     nn.quantize(layer.feed_forward, self.group_size, self.bits)
+        #     nn.quantize(layer.attention, self.group_size, self.bits)
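
For context, a minimal standalone sketch of the nn.quantize transform that T2SDecoder.quantized applies to the whole module tree. Group size and bits mirror the defaults above; the mode="mxfp4" variant needs a recent MLX, so this sketch sticks to the plain affine path:

import mlx.core as mx
import mlx.nn as nn


class TinyFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(256, 256)  # stand-in for a feed-forward layer

    def __call__(self, x):
        return self.linear(x)


model = TinyFFN()
nn.quantize(model, group_size=32, bits=8)  # swaps Linear for QuantizedLinear in place
y = model(mx.random.normal(shape=(1, 256)))
print(y.shape)  # (1, 256)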
GPT_SoVITS/Accelerate/MLX/backends/mlx_static.py
ADDED
@@ -0,0 +1,99 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+
+from ..structs_mlx import KVCache, KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        k, v = kv_cache
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=attn_mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
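
The static backend attends over the full preallocated cache on every step, so its cost is governed by the (batch, heads, max_seq, head_dim) layout returned by KVCacheHND.init_cache. A shape check under assumed dimensions:

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX.t2s_model_abc import KVCacheHND

k_cache, v_cache = KVCacheHND.init_cache(
    batch_size=2, max_seq_length=2000, n_heads=16, head_dim=32, dtype=mx.float16
)
print(k_cache.shape)  # (2, 16, 2000, 32)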
GPT_SoVITS/Accelerate/MLX/backends/mlx_varlen.py
ADDED
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+from typing import cast
+
+import mlx.core as mx
+
+from ..structs_mlx import KVCache, KVCacheQ
+from ..t2s_model_abc import (
+    AttentionABC,
+    KVCacheHND,
+    T2SDecoderABC,
+    TransformerBlockABC,
+    TransformerDecoderABC,
+)
+
+Array = mx.array
+
+
+class Attention(AttentionABC):
+    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
+        super().__init__(n_head, hidden_dim, max_seq_length)
+        self.kc_class = KVCacheHND
+
+    def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
+        bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
+
+        q, k, v = self.in_proj(x).split(3, axis=-1)
+
+        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
+
+        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
+
+        kv_cache = self.kc_class.update_cache(input_pos, k, v, kv_cache, cache_idx)
+        assert len(kv_cache) == 2
+
+        max_idx = int(input_pos.max())
+
+        q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))
+
+        mask = attn_mask[..., :max_idx]
+
+        attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
+
+        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)
+
+        attn = self.out_proj(attn)
+
+        return attn
+
+
+class TransformerBlock(TransformerBlockABC):
+    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
+        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)
+
+        self.attention = Attention(n_head, hidden_dim, max_seq_length)
+
+
+class TransformerDecoder(TransformerDecoderABC):
+    def __init__(
+        self,
+        hidden_dim: int,
+        n_layer: int,
+        n_head: int,
+        ffn_dim: int,
+        vocab_size: int,
+        max_seq_length: int,
+        max_batch_size: int,
+    ) -> None:
+        super().__init__(
+            hidden_dim,
+            n_layer,
+            n_head,
+            ffn_dim,
+            vocab_size,
+            max_seq_length,
+            max_batch_size,
+        )
+
+        self.layers = [
+            TransformerBlock(
+                n_head,
+                ffn_dim,
+                hidden_dim,
+                max_seq_length,
+            )
+            for _ in range(n_layer)
+        ]
+
+
+class T2SDecoder(T2SDecoderABC):
+    def __init__(
+        self,
+        config: dict,
+        max_seq_length: int = 2000,
+        max_batch_size: int = 10,
+    ) -> None:
+        super().__init__(config, max_seq_length, max_batch_size)
+
+        self.h = TransformerDecoder(
+            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
+        )
+
+        self.kv_class = KVCacheHND
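
The varlen variant differs from the static one only in slicing queries, caches, and mask down to the current maximum position before attending; both reduce to the same fused kernel. A minimal sketch of that call with made-up shapes ([batch, heads, seq, head_dim]; boolean mask where True means visible, assuming an MLX version that accepts boolean masks, as the backends above do):

import mlx.core as mx

B, H, S, D = 1, 8, 16, 64
q = mx.random.normal(shape=(B, H, S, D))
k = mx.random.normal(shape=(B, H, S, D))
v = mx.random.normal(shape=(B, H, S, D))
mask = ~mx.triu(mx.ones((S, S), dtype=mx.bool_), k=1)  # causal: lower triangle visible

out = mx.fast.scaled_dot_product_attention(q, k, v, scale=D**-0.5, mask=mask)
print(out.shape)  # (1, 8, 16, 64)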
GPT_SoVITS/Accelerate/MLX/sample_funcs_mlx.py
ADDED
@@ -0,0 +1,65 @@
+from typing import Protocol, cast
+
+import mlx.core as mx
+
+Array = mx.array
+
+
+class SampleProtocolMLX(Protocol):
+    @staticmethod
+    def __call__(
+        logits: Array,
+        previous_tokens: Array,
+        temperature: float,
+        top_k: int,
+        top_p: float,
+        repetition_penalty: float,
+    ) -> Array: ...
+
+
+class sample_naive(SampleProtocolMLX):
+    # @partial(mx.compile)
+    @staticmethod
+    def __call__(
+        logits,
+        previous_tokens,
+        temperature,
+        top_k,
+        top_p,
+        repetition_penalty,
+    ):
+        if temperature <= 1e-5:
+            probs = mx.softmax(logits, axis=-1)
+            return mx.argmax(probs, axis=-1, keepdims=True).astype(mx.int32)
+
+        if repetition_penalty != 1.0:
+            batch_idx = mx.arange(cast(tuple[int, ...], previous_tokens.shape)[0])
+            previous_tokens = previous_tokens.astype(mx.int64)
+            selected_logits = logits[batch_idx, previous_tokens]
+            selected_logits = mx.where(
+                selected_logits < 0, selected_logits * repetition_penalty, selected_logits / repetition_penalty
+            )
+            logits[batch_idx, previous_tokens] = selected_logits
+
+        if top_p < 1.0:
+            sorted_indices = mx.argsort(-logits, axis=-1)
+            sorted_logits = mx.take_along_axis(logits, sorted_indices, axis=-1)
+            cum_probs = mx.cumsum(mx.softmax(sorted_logits, axis=-1), axis=-1)
+            sorted_indices_to_remove = cum_probs > top_p
+            sorted_indices_to_remove[:, -1] = False
+            indices_to_remove = mx.zeros_like(logits).astype(mx.bool_)
+            batch_indices = mx.arange(cast(tuple[int, ...], logits.shape)[0])[:, None]
+            indices_to_remove[batch_indices, sorted_indices] = sorted_indices_to_remove
+            logits = mx.where(indices_to_remove, -mx.inf, logits)
+
+        if temperature < 1.0:
+            logits = logits / temperature
+
+        v = mx.topk(logits, top_k)
+        pivot = mx.expand_dims(v[:, 0], -1)
+        logits = mx.where(logits < pivot, -mx.inf, logits)
+
+        gumbel_noise = mx.random.gumbel(shape=cast(tuple[int, ...], logits.shape), dtype=logits.dtype)
+        idx_next = mx.argmax(logits + gumbel_noise, axis=-1, keepdims=True).astype(mx.int32)
+
+        return idx_next
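
A smoke test for the sampler above with hypothetical shapes; repetition_penalty=1.0 is chosen so the history-rewrite branch stays out of the way:

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX.sample_funcs_mlx import sample_naive

logits = mx.random.normal(shape=(2, 1025))  # batch of 2 over a 1025-token vocab
history = mx.zeros((2, 8), dtype=mx.int64)  # previously generated tokens

token = sample_naive()(
    logits=logits,
    previous_tokens=history,
    temperature=0.8,
    top_k=5,
    top_p=0.9,
    repetition_penalty=1.0,  # 1.0 disables the penalty branch in this demo
)
print(token.shape)  # (2, 1), dtype int32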
GPT_SoVITS/Accelerate/MLX/structs_mlx.py
ADDED
@@ -0,0 +1,152 @@
+"""
+Modified From https://github.com/XXXXRT666/GPT-SoVITS
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, MutableSequence, Protocol, TypeAlias, cast
+
+import mlx.core as mx
+import torch
+
+from ..PyTorch.structs import T2SRequest
+from .sample_funcs_mlx import SampleProtocolMLX, sample_naive
+
+Tensor = torch.Tensor
+Array = mx.array
+
+
+@dataclass(slots=True)
+class T2SRequestMLX:
+    x: List[Array]
+    x_lens: Array
+    prompts: Array
+    bert_feature: List[Array]
+    valid_length: int
+    top_k: int = 5
+    top_p: float = 1
+    early_stop_num: int = -1
+    temperature: float = 1.0
+    repetition_penalty: float = 1.35
+
+    @classmethod
+    def from_torch(cls, request: T2SRequest) -> T2SRequestMLX:
+        x = list(map(lambda tensor: mx.array(tensor.cpu()), request.x))
+        x_lens = mx.array(request.x_lens.cpu())
+        prompts = mx.array(request.prompts.cpu())
+        bert_feature = list(map(lambda tensor: mx.array(tensor.cpu()), request.bert_feature))
+
+        return cls(
+            x,
+            x_lens,
+            prompts,
+            bert_feature,
+            request.valid_length,
+            request.top_k,
+            request.top_p,
+            request.early_stop_num,
+            request.temperature,
+            request.repetition_penalty,
+        )
+
+
+KVCache: TypeAlias = tuple[Array, Array]
+KVCacheQ: TypeAlias = tuple[tuple[Array, Array, Array], tuple[Array, Array, Array], tuple[int, int]]
+
+
+class KVCacheProtocol(Protocol):
+    @staticmethod
+    def empty(kv_cache: KVCache | KVCacheQ) -> None: ...
+
+    @staticmethod
+    def update_cache(
+        input_pos: Array, k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array
+    ) -> KVCache | KVCacheQ: ...
+
+    @staticmethod
+    def prefill_kv(k_val: Array, v_val: Array, kv_cache: KVCache | KVCacheQ) -> None: ...
+
+    @staticmethod
+    def init_cache(
+        batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype, *args, **kwds
+    ) -> KVCache | KVCacheQ: ...
+
+
+class T2SDecoderProtocol(Protocol):
+    max_seq_length: int
+    EOS: int
+    n_head: int
+
+    def embed(self, x: list[Array], y: Array, bert_features: list[Array]) -> Array: ...
+
+
+class T2SSessionMLX:
+    def __init__(
+        self,
+        decoder: T2SDecoderProtocol,
+        request_torch: T2SRequest,
+        sample_func: type[SampleProtocolMLX] = sample_naive,
+        device: mx.Device = mx.Device(mx.cpu),
+        dtype: mx.Dtype = mx.float32,
+    ):
+        with mx.stream(device):
+            request = T2SRequestMLX.from_torch(request_torch)
+
+            self.decoder = decoder
+            self.request = request
+            self.device = device
+            self.dtype = dtype
+
+            bsz = len(request.x)
+            y_len: int = cast(tuple[int, ...], request.prompts.shape)[-1]
+            self.bsz = bsz
+            self.y_len = y_len
+
+            # Cache
+            self.kv_cache: MutableSequence[KVCache | KVCacheQ]
+            self.sample = sample_func()
+
+            # Forward args
+            self.x = [i.astype(mx.int32) for i in request.x]
+            self.x_lens = request.x_lens.astype(mx.int32)
+            self.y = mx.zeros((bsz, decoder.max_seq_length)).astype(mx.int32)
+            self.y[:, : cast(tuple[int, ...], request.prompts.shape)[-1]] = request.prompts.astype(mx.int32)
+            self.bert_feature = [i.astype(dtype) for i in request.bert_feature]
+
+            self.prefill_len = self.x_lens + cast(tuple[int, ...], request.prompts.shape)[1]
+
+            self.input_pos = mx.zeros_like(self.prefill_len)
+            self.input_pos += self.prefill_len
+
+            # EOS
+            self.completed = mx.array([False] * len(self.x)).astype(mx.bool_)
+            self.y_results: List[Array] = [None] * len(self.x)  # type: ignore
+
+            self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)
+
+            max_len = int(self.prefill_len.max(-1))
+            attn_mask = mx.zeros(shape=(bsz, max_len, max_len), dtype=mx.bool_)
+
+            for bs in range(bsz):
+                pos = int(self.x_lens[bs])
+                seq_len = pos + y_len
+
+                attn_mask[bs, :seq_len, :pos] = True
+
+                ar_mask = ~mx.triu(
+                    x=mx.ones(
+                        shape=(
+                            y_len,
+                            y_len,
+                        ),
+                        dtype=mx.bool_,
+                    ),
+                    k=1,
+                )
+                attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask
+
+            attn_mask = mx.repeat(mx.expand_dims(attn_mask, 1), decoder.n_head, 1)
+            self.attn_mask = attn_mask
+
+            mx.eval(self.attn_mask)
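
The torch-to-MLX bridge in from_torch rests on one primitive used throughout this file: mx.array ingests a CPU torch.Tensor directly. A minimal check:

import mlx.core as mx
import torch

t = torch.arange(6).reshape(2, 3)  # any CPU tensor
a = mx.array(t.cpu())              # copied into an MLX array
print(a.shape)  # (2, 3)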
GPT_SoVITS/Accelerate/MLX/t2s_engine_mlx.py
ADDED
@@ -0,0 +1,238 @@
+import gc
+import os
+import time
+import traceback
+from typing import cast
+
+import mlx.core as mx
+import torch
+from rich.progress import BarColumn, Progress, TextColumn
+
+from ..logger import SpeedColumnToken, console, logger
+from ..PyTorch.structs import T2SEngineProtocol, T2SRequest, T2SResult
+from .backends import mlx_quantized, mlx_static, mlx_varlen
+from .structs_mlx import T2SSessionMLX
+from .t2s_model_abc import T2SDecoderABC
+
+Array = mx.array
+Tensor = torch.Tensor
+
+
+class T2SEngine(T2SEngineProtocol):
+    def __init__(
+        self,
+        decoder_model: T2SDecoderABC,
+        device: mx.Device | str = mx.Device(mx.cpu),
+        dtype: torch.dtype | mx.Dtype = torch.float32,
+    ) -> None:
+        if isinstance(device, str):
+            match device:
+                case "mx.cpu":
+                    device = mx.Device(mx.cpu)
+                case "mx.gpu":
+                    device = mx.Device(mx.gpu)
+
+        match dtype:
+            case torch.float32:
+                dtype = mx.float32
+            case torch.float16:
+                dtype = mx.float16
+            case torch.bfloat16:
+                dtype = mx.bfloat16
+
+        device = cast(mx.Device, device)
+        dtype = cast(mx.Dtype, dtype)
+
+        assert device.type.value in {0, 1}
+        assert dtype in {mx.float16, mx.bfloat16, mx.float32}
+
+        self.device = device
+        self.dtype = dtype
+
+        mx.set_default_device(device)
+        decoder_model.set_dtype(self.dtype)
+
+        self.decoder_model: T2SDecoderABC = decoder_model
+        self.decoder_model.compile()
+
+    def _handle_request(self, request: T2SRequest):
+        decoder = self.decoder_model
+        session = T2SSessionMLX(decoder, request, device=self.device, dtype=self.dtype)
+        batch_idx = mx.arange(session.bsz)
+
+        t1 = 0.0
+        infer_speed = 0.0
+        infer_time = 0.0
+
+        with (
+            mx.stream(session.device),
+            Progress(
+                TextColumn("[cyan]{task.description}"),
+                BarColumn(),
+                TextColumn("{task.completed}/{task.total}"),
+                SpeedColumnToken(show_speed=True),
+                console=console,
+                transient=True,
+            ) as progress,
+        ):
+            max_token = min(2000 - int(session.input_pos.max()), 1500)
+
+            task = progress.add_task("T2S Decoding", total=max_token)
+            for idx in range(1500):
+                progress.update(task, advance=1)
+                if idx == 0:
+                    session.kv_cache = decoder.init_cache(session.bsz)
+                    xy_dec = decoder.h.prefill(
+                        session.xy_pos,
+                        session.attn_mask,
+                        session.kv_cache,
+                    )  # bs, seq_len, embed_dim
+                    xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
+                else:
+                    args, kwds = decoder.pre_forward(session)
+                    xy_dec = decoder.h(
+                        session.input_pos,
+                        session.xy_pos,
+                        session.kv_cache,
+                        batch_idx,
+                        *args,
+                        **kwds,
+                    )
+
+                decoder.post_forward(idx, session)
+                logits = decoder.ar_predict_layer(xy_dec[:, -1])
+                session.input_pos += 1
+
+                if idx == 0:
+                    logits[:, -1] = -mx.inf
+
+                samples = session.sample(
+                    logits=logits,
+                    previous_tokens=session.y[:, : session.y_len + idx],
+                    top_k=request.top_k,
+                    top_p=request.top_p,
+                    repetition_penalty=request.repetition_penalty,
+                    temperature=request.temperature,
+                )
+
+                session.y[batch_idx, session.y_len + idx] = samples
+
+                argmax_token = mx.argmax(logits, axis=-1)
+                sample_token = samples.squeeze(1)
+                EOS_mask = (cast(Array, argmax_token == decoder.EOS)) | (sample_token == decoder.EOS)
+
+                newly_done_mask = EOS_mask & (~session.completed)
+                newly_done_indices = mx.where(newly_done_mask, batch_idx, -1)
+                pos = mx.where(newly_done_indices != -1, batch_idx, session.bsz)
+                pos_sorted = mx.sort(pos, axis=0)
+                valid_count = session.bsz - mx.sum(cast(Array, pos_sorted == session.bsz))
+                pos_final = pos_sorted[: int(valid_count)]
+                newly_done_indices = mx.expand_dims(newly_done_indices[pos_final], 0)
+
+                if newly_done_indices.size > 0:
+                    for i in newly_done_indices:
+                        session.y_results[int(i)] = session.y[i, session.y_len : session.y_len + idx]
+                    session.completed[newly_done_indices] = True
+
+                if mx.all(session.completed).item():
+                    if session.y[:, session.y_len :].sum() == 0:
+                        session.y_results = [mx.array([0]) for _ in range(session.bsz)]
+                        logger.error("Bad Zero Prediction")
+                    else:
+                        logger.info(
+                            f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[cast(tuple[int, ...], i.shape)[-1] for i in session.y_results].__str__().strip('[]')}"
+                        )
+                        logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                    infer_time = time.perf_counter() - t1
+                    infer_speed = (idx - 1) / infer_time
+                    break
+
+                if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
+                    for j in range(session.bsz):
+                        if not session.completed[j].item():
+                            session.y_results[j] = session.y[[j], session.y_len : session.y_len + 1499]
+                            session.completed[j] = True
+                    logger.error("Bad Full Prediction")
+                    logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
+                    infer_time = time.perf_counter() - t1
+                    infer_speed = (idx - 1) / infer_time
+                    break
+
+                y_emb = decoder.ar_audio_embedding(samples)
+                session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)
+                mx.eval(session.xy_pos, session.y)
+
+                if idx == 1:
+                    t1 = time.perf_counter()
+
+                if idx % 100 == 0:
+                    mx.clear_cache()
+
+        match session.device:
+            case mx.gpu:
+                mx.clear_cache()
+            case mx.cpu:
+                gc.collect()
+
+        result_mlx = session.y_results[: request.valid_length]
+        mx.eval(result_mlx)
+        result = [torch.tensor(k) for k in result_mlx]
+        return result, infer_speed, infer_time
+
+    def generate(self, request: T2SRequest):
+        try:
+            result, infer_speed, infer_time = self._handle_request(request)
+            t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
+        except Exception as e:
+            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
+        return t2s_result
+
+    @staticmethod
+    def replace_key(state_dict: dict[str, Tensor]):
+        state_dict_mlx: list[tuple[str, Array]] = []
+        for key, value in state_dict.items():
+            key = (
+                key.replace("model.", "")
+                .replace("in_proj_", "in_proj.")
+                .replace("self_attn", "attention")
+                .replace("linear", "feed_forward.linear")
+                .replace("norm1", "attention_norm")
+                .replace("norm2", "ffn_norm")
+            )
+            value_mlx = mx.array(value)
+            state_dict_mlx.append((key, value_mlx))
+        return state_dict_mlx
+
+    @staticmethod
+    def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "MLX-Varlen"):
+        logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
+        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
+        config = dict_s1["config"]
+        match backend:
+            case "MLX-Varlen":
+                decoder_cls: type[T2SDecoderABC] = mlx_varlen.T2SDecoder
+            case "MLX-Static":
+                decoder_cls = mlx_static.T2SDecoder
+            case "MLX-Quantized-Affine" | "MLX-Quantized-MXFP4":
+                decoder_cls = mlx_quantized.T2SDecoder
+            case _:
+                raise RuntimeError(f"Backend {backend} Not Found")
+
+        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
+        state_dict = dict_s1["weight"]
+        state_dict_mlx = T2SEngine.replace_key(state_dict)
+        decoder.load_weights(state_dict_mlx)
+        decoder.eval()
+        mx.eval(decoder)
+
+        if "Quantized" in backend and isinstance(decoder, mlx_quantized.T2SDecoder):
+            if backend == "MLX-Quantized-Affine":
+                decoder.set_mode("affine")
+            elif backend == "MLX-Quantized-MXFP4":
+                decoder.set_mode("mxfp4")
+            else:
+                raise RuntimeError(f"Quantized Backend {backend} Not Supported")
+            decoder.quantized()
+            mx.eval(decoder)
+
+        return decoder
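
Putting the engine together end to end, a hedged sketch: the checkpoint path is hypothetical, and building the PyTorch-side T2SRequest is left to the caller (see inference_webui.py):

import mlx.core as mx

from GPT_SoVITS.Accelerate.MLX import T2SEngineMLX

# backend is one of MLX-Varlen / MLX-Static / MLX-Quantized-Affine / MLX-Quantized-MXFP4
decoder = T2SEngineMLX.load_decoder("GPT_weights/example-e15.ckpt", backend="MLX-Varlen")
engine = T2SEngineMLX(decoder, device=mx.Device(mx.gpu), dtype=mx.float16)

# result = engine.generate(request)  # request: T2SRequest; result.result holds
#                                    # the semantic-token tensors on Success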
GPT_SoVITS/Accelerate/MLX/t2s_model_abc.py
ADDED
@@ -0,0 +1,530 @@
+from __future__ import annotations
+
+import math
+from abc import ABC, abstractmethod
+from typing import MutableSequence, cast
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .structs_mlx import KVCache, KVCacheProtocol, KVCacheQ, T2SDecoderProtocol, T2SSessionMLX
+
+Array = mx.array
+
+
+class TokenEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        vocab_size: int,
+    ):
+        super().__init__()
+
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+
+        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    def embedding(self, index: int):
+        return self.word_embeddings.weight[index : index + 1]
+
+    def __call__(self, x: Array):
+        x = self.word_embeddings(x)
+        return x
+
+
+class SinePositionalEmbedding(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        scale: bool = False,
+        max_batch_size: int = 10,
+        max_seq_len: int = 2000,
+    ):
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
+        self.alpha = mx.ones(1)
+        self.max_batch_size = max_batch_size
+        self.max_seq_len = max_seq_len
+
+        self.reverse = False
+        self._pe = mx.zeros((max_batch_size, max_seq_len, embedding_dim))
+        self.compute_pe()
+
+    def compute_pe(self):
+        """Reset the positional encodings."""
+
+        if self.reverse:
+            position = mx.expand_dims(mx.arange(self.max_seq_len - 1, -1, -1.0), axis=1)
+        else:
+            position = mx.expand_dims(mx.arange(self.max_seq_len), axis=1)
+        div_term = mx.exp(
+            mx.arange(
+                0,
+                self.embedding_dim,
+                2,
+            )
+            * -(math.log(10000.0) / self.embedding_dim)
+        )
+        pe = self._pe
+        pe[:, :, 0::2] = mx.sin(position * div_term)
+        pe[:, :, 1::2] = mx.cos(position * div_term)
+
+    def __call__(self, input_pos: Array, x: Array):
+        """
+        Args:
+            input_pos (Array): [batch_size, ]
+            x (Array): [batch_size, 1, embed_dim]
+
+        Returns:
+            embedded_x (Array): [batch_size, 1, embed_dim]
+        """
+
+        batch_size = cast(tuple[int, ...], x.shape)[0]
+        pe_values = self._pe[mx.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)
+
+        return x * self.x_scale + self.alpha * mx.expand_dims(pe_values, 1)  # (batch_size, 1, embed_dim)
+
+    def prefill(self, x: Array):
+        """
+        Args:
+            x (Array): [batch_size, seq_len, embed_dim]
+
+        Returns:
+            embedded_x (Array): [batch_size, seq_len, embed_dim]
+        """
+        pe_values = self._pe[:, : cast(tuple[int, ...], x.shape)[-2]]
+        return x * self.x_scale + self.alpha * pe_values
+
+
+class KVCacheHND(KVCacheProtocol):
+    @staticmethod
+    def empty(kv_cache):
+        assert len(kv_cache) == 2
+        k_cache, v_cache = kv_cache
+
+        k_cache[:] = 0
+        v_cache[:] = 0
+
+    @staticmethod
+    def update_cache(input_pos, k_val, v_val, kv_cache, cache_idx):
+        # input_pos: [B, ], k_val: [B, H, 1, D]
+        assert len(kv_cache) == 2
+        k_out, v_out = kv_cache
+        ip0 = input_pos - 1
+
+        k_out[cache_idx, :, ip0, None] = k_val
+        v_out[cache_idx, :, ip0, None] = v_val
+
+        return k_out, v_out
+
+    @staticmethod
+    def prefill_kv(k_val, v_val, kv_cache):
+        # k_val: [B, S, H, D]
+        assert len(kv_cache) == 2
+        k_cache, v_cache = kv_cache
+
+        k_cache[..., : cast(tuple[int, ...], k_val.shape)[1], :] = k_val.swapaxes(1, 2)
+        v_cache[..., : cast(tuple[int, ...], v_val.shape)[1], :] = v_val.swapaxes(1, 2)
+
+    @staticmethod
+    def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: mx.Dtype) -> KVCache:
+        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
+
+        return (mx.zeros(cache_shape, dtype=dtype), mx.zeros(cache_shape, dtype=dtype))
+
+
+class KVCacheHNDQuantized(KVCacheProtocol):
+    @staticmethod
+    def _el_per_int(bits: int) -> int:
+        return 32 // bits
+
+    @staticmethod
+    def _packed_dim(head_dim: int, bits: int = 8) -> int:
+        el_per_int = KVCacheHNDQuantized._el_per_int(bits)
+        if head_dim % el_per_int != 0:
+            raise ValueError(f"{head_dim=} is not divisible by {el_per_int=} ({bits=})")
+        return head_dim // el_per_int
+
+    @staticmethod
+    def _group_count(head_dim: int, group_size: int = 32) -> int:
+        assert group_size in {32, 64, 128}
+        if head_dim % group_size != 0:
+            raise ValueError(f"{head_dim} is not divisible by {group_size=}")
+        return head_dim // group_size
+
+    @staticmethod
+    def empty(kv_cache) -> None:
+        assert len(kv_cache) == 3
+        (k_q, k_s, k_b), (v_q, v_s, v_b), (_, __) = kv_cache
+
+        k_q[:] = 0
+        k_s[:] = 0
+        k_b[:] = 0
+        v_q[:] = 0
+        v_s[:] = 0
+        v_b[:] = 0
+
+    @staticmethod
+    def update_cache(
+        input_pos,
+        k_val,
+        v_val,
+        kv_cache,
+        cache_idx,
+    ):
+        # input_pos: [B, ], k_val: [B, H, 1, D]
+
+        assert len(kv_cache) == 3
+        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
+
+        k_q, k_s, k_b = mx.quantize(k_val, group_size=group_size, bits=bits)
+        v_q, v_s, v_b = mx.quantize(v_val, group_size=group_size, bits=bits)
+
+        ip0 = input_pos - 1
+
+        k_q_out[cache_idx, :, ip0, None] = k_q
+        k_s_out[cache_idx, :, ip0, None] = k_s
+        k_b_out[cache_idx, :, ip0, None] = k_b
+
+        v_q_out[cache_idx, :, ip0, None] = v_q
+        v_s_out[cache_idx, :, ip0, None] = v_s
+        v_b_out[cache_idx, :, ip0, None] = v_b
+
+        return (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits)
+
+    @staticmethod
+    def prefill_kv(
+        k_val,
+        v_val,
+        kv_cache,
+    ) -> None:
+        assert len(kv_cache) == 3
+        (k_q_out, k_s_out, k_b_out), (v_q_out, v_s_out, v_b_out), (group_size, bits) = kv_cache
+
+        S = cast(tuple[int, ...], k_val.shape)[1]
+
+        k_sw = k_val.swapaxes(1, 2)
+        v_sw = v_val.swapaxes(1, 2)
+
+        k_q, k_s, k_b = mx.quantize(k_sw, group_size=group_size, bits=bits)
+        v_q, v_s, v_b = mx.quantize(v_sw, group_size=group_size, bits=bits)
+
+        k_q_out[..., :S, :] = k_q
+        k_s_out[..., :S, :] = k_s
|
| 220 |
+
k_b_out[..., :S, :] = k_b
|
| 221 |
+
|
| 222 |
+
v_q_out[..., :S, :] = v_q
|
| 223 |
+
v_s_out[..., :S, :] = v_s
|
| 224 |
+
v_b_out[..., :S, :] = v_b
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def init_cache(
|
| 228 |
+
batch_size: int,
|
| 229 |
+
max_seq_length: int,
|
| 230 |
+
n_heads: int,
|
| 231 |
+
head_dim: int,
|
| 232 |
+
dtype: mx.Dtype,
|
| 233 |
+
*,
|
| 234 |
+
group_size: int = 32,
|
| 235 |
+
bits: int = 8,
|
| 236 |
+
) -> KVCacheQ:
|
| 237 |
+
packed_dim = KVCacheHNDQuantized._packed_dim(head_dim, bits=bits)
|
| 238 |
+
group_cnt = KVCacheHNDQuantized._group_count(head_dim, group_size=group_size)
|
| 239 |
+
|
| 240 |
+
packed_shape = (batch_size, n_heads, max_seq_length, packed_dim)
|
| 241 |
+
group_shape = (batch_size, n_heads, max_seq_length, group_cnt)
|
| 242 |
+
|
| 243 |
+
k_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
| 244 |
+
k_s = mx.zeros(group_shape, dtype=dtype)
|
| 245 |
+
k_b = mx.zeros(group_shape, dtype=dtype)
|
| 246 |
+
|
| 247 |
+
v_q = mx.zeros(packed_shape, dtype=mx.uint32)
|
| 248 |
+
v_s = mx.zeros(group_shape, dtype=dtype)
|
| 249 |
+
v_b = mx.zeros(group_shape, dtype=dtype)
|
| 250 |
+
|
| 251 |
+
return (k_q, k_s, k_b), (v_q, v_s, v_b), (group_size, bits)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class AttentionABC(ABC, nn.Module):
|
| 255 |
+
def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int, *args, **kwds):
|
| 256 |
+
super().__init__()
|
| 257 |
+
|
| 258 |
+
self.n_head = n_head
|
| 259 |
+
self.hidden_dim = hidden_dim
|
| 260 |
+
assert hidden_dim % n_head == 0
|
| 261 |
+
self.head_dim = hidden_dim // n_head
|
| 262 |
+
|
| 263 |
+
self.max_seq_length = max_seq_length
|
| 264 |
+
|
| 265 |
+
# key, query, value projections for all heads, but in a batch
|
| 266 |
+
self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
|
| 267 |
+
self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
| 268 |
+
|
| 269 |
+
self.scale = 1 / math.sqrt(self.head_dim)
|
| 270 |
+
|
| 271 |
+
self.kc_class: KVCacheProtocol
|
| 272 |
+
|
| 273 |
+
@abstractmethod
|
| 274 |
+
def __call__(
|
| 275 |
+
self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array
|
| 276 |
+
) -> Array: ...
|
| 277 |
+
|
| 278 |
+
def prefill(self, x: Array, kv_cache: KVCache | KVCacheQ, attn_mask: Array):
|
| 279 |
+
bsz, seqlen, _ = cast(tuple[int, ...], x.shape)
|
| 280 |
+
|
| 281 |
+
q, k, v = self.in_proj(x).split(3, axis=-1)
|
| 282 |
+
|
| 283 |
+
q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))
|
| 284 |
+
|
| 285 |
+
self.kc_class.prefill_kv(k, v, kv_cache)
|
| 286 |
+
|
| 287 |
+
q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))
|
| 288 |
+
|
| 289 |
+
attn = mx.fast.scaled_dot_product_attention(q, k, v, mask=attn_mask, scale=self.scale)
|
| 290 |
+
|
| 291 |
+
attn = mx.nan_to_num(attn)
|
| 292 |
+
|
| 293 |
+
attn = attn.swapaxes(1, 2).reshape(1, -1, self.hidden_dim)
|
| 294 |
+
|
| 295 |
+
output = self.out_proj(attn)
|
| 296 |
+
|
| 297 |
+
return output
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class FeedForward(nn.Module):
|
| 301 |
+
def __init__(self, dim: int, hidden_dim: int) -> None:
|
| 302 |
+
super().__init__()
|
| 303 |
+
|
| 304 |
+
self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
|
| 305 |
+
self.linear2 = nn.Linear(hidden_dim, dim, bias=True)
|
| 306 |
+
|
| 307 |
+
def __call__(self, x: Array):
|
| 308 |
+
return self.linear2(nn.relu(self.linear1(x)))
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class TransformerBlockABC(nn.Module):
|
| 312 |
+
def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int, *args, **kwds) -> None:
|
| 313 |
+
super().__init__()
|
| 314 |
+
|
| 315 |
+
self.hidden_dim = hidden_dim
|
| 316 |
+
self.max_seq_length = max_seq_length
|
| 317 |
+
|
| 318 |
+
self.attention: AttentionABC
|
| 319 |
+
|
| 320 |
+
self.feed_forward = FeedForward(hidden_dim, ffn_dim)
|
| 321 |
+
self.attention_norm = nn.LayerNorm(self.hidden_dim)
|
| 322 |
+
self.ffn_norm = nn.LayerNorm(self.hidden_dim)
|
| 323 |
+
|
| 324 |
+
def __call__(self, x: Array, input_pos: Array, kv_cache: KVCache | KVCacheQ, cache_idx: Array, attn_mask: Array):
|
| 325 |
+
h = self.attention_norm(
|
| 326 |
+
x
|
| 327 |
+
+ self.attention(
|
| 328 |
+
x,
|
| 329 |
+
input_pos,
|
| 330 |
+
kv_cache,
|
| 331 |
+
cache_idx,
|
| 332 |
+
attn_mask,
|
| 333 |
+
)
|
| 334 |
+
)
|
| 335 |
+
out = self.ffn_norm(h + self.feed_forward(h))
|
| 336 |
+
return out
|
| 337 |
+
|
| 338 |
+
def prefill(self, x: Array, attn_mask: Array, kv_cache: KVCache | KVCacheQ):
|
| 339 |
+
h = self.attention_norm(
|
| 340 |
+
x
|
| 341 |
+
+ self.attention.prefill(
|
| 342 |
+
x,
|
| 343 |
+
kv_cache,
|
| 344 |
+
attn_mask,
|
| 345 |
+
)
|
| 346 |
+
)
|
| 347 |
+
out = self.ffn_norm(h + self.feed_forward(h))
|
| 348 |
+
|
| 349 |
+
return out
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class TransformerDecoderABC(nn.Module):
|
| 353 |
+
def __init__(
|
| 354 |
+
self,
|
| 355 |
+
hidden_dim: int,
|
| 356 |
+
n_layer: int,
|
| 357 |
+
n_head: int,
|
| 358 |
+
ffn_dim: int,
|
| 359 |
+
vocab_size: int,
|
| 360 |
+
max_seq_length: int,
|
| 361 |
+
max_batch_size: int,
|
| 362 |
+
*args,
|
| 363 |
+
**kwds,
|
| 364 |
+
) -> None:
|
| 365 |
+
super().__init__()
|
| 366 |
+
|
| 367 |
+
self.hidden_dim = hidden_dim
|
| 368 |
+
self.n_head = n_head
|
| 369 |
+
assert hidden_dim % n_head == 0
|
| 370 |
+
|
| 371 |
+
self.head_dim = hidden_dim // n_head
|
| 372 |
+
self.vocab_size = vocab_size
|
| 373 |
+
|
| 374 |
+
self.n_layer = n_layer
|
| 375 |
+
|
| 376 |
+
self.layers: MutableSequence[TransformerBlockABC]
|
| 377 |
+
|
| 378 |
+
self.max_seq_length = max_seq_length
|
| 379 |
+
self.max_batch_size = max_batch_size
|
| 380 |
+
|
| 381 |
+
def __call__(
|
| 382 |
+
self,
|
| 383 |
+
input_pos: Array,
|
| 384 |
+
x: Array,
|
| 385 |
+
kv_caches: MutableSequence[KVCache | KVCacheQ],
|
| 386 |
+
cache_idx: Array,
|
| 387 |
+
*args,
|
| 388 |
+
**kwds,
|
| 389 |
+
):
|
| 390 |
+
for layer, kv_cache in zip(self.layers, kv_caches):
|
| 391 |
+
x = layer(
|
| 392 |
+
x,
|
| 393 |
+
input_pos,
|
| 394 |
+
kv_cache,
|
| 395 |
+
cache_idx,
|
| 396 |
+
*args,
|
| 397 |
+
**kwds,
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
return x
|
| 401 |
+
|
| 402 |
+
def prefill(self, x: Array, mask: Array, kv_caches: MutableSequence[KVCache | KVCacheQ]):
|
| 403 |
+
for layer, kv_cache in zip(self.layers, kv_caches):
|
| 404 |
+
x = layer.prefill(
|
| 405 |
+
x,
|
| 406 |
+
mask,
|
| 407 |
+
kv_cache,
|
| 408 |
+
)
|
| 409 |
+
return x
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class T2SDecoderABC(nn.Module, T2SDecoderProtocol):
|
| 413 |
+
def __init__(
|
| 414 |
+
self,
|
| 415 |
+
config: dict,
|
| 416 |
+
max_seq_length: int = 2000,
|
| 417 |
+
max_batch_size: int = 10,
|
| 418 |
+
) -> None:
|
| 419 |
+
super().__init__()
|
| 420 |
+
|
| 421 |
+
hidden_dim: int = config["model"]["hidden_dim"]
|
| 422 |
+
embedding_dim: int = config["model"]["embedding_dim"]
|
| 423 |
+
n_head: int = config["model"]["head"]
|
| 424 |
+
n_layer: int = config["model"]["n_layer"]
|
| 425 |
+
vocab_size: int = config["model"]["vocab_size"]
|
| 426 |
+
phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
|
| 427 |
+
EOS: int = config["model"]["EOS"]
|
| 428 |
+
ffn_dim: int = hidden_dim * 4
|
| 429 |
+
|
| 430 |
+
self.n_layer = int(n_layer)
|
| 431 |
+
self.hidden_dim = int(hidden_dim)
|
| 432 |
+
self.n_head = int(n_head)
|
| 433 |
+
assert hidden_dim % n_head == 0
|
| 434 |
+
|
| 435 |
+
self.head_dim = int(hidden_dim // n_head)
|
| 436 |
+
self.embedding_dim = int(embedding_dim)
|
| 437 |
+
self.ffn_dim = int(ffn_dim)
|
| 438 |
+
self.vocab_size = int(vocab_size)
|
| 439 |
+
self.phoneme_vocab_size = int(phoneme_vocab_size)
|
| 440 |
+
self.max_seq_length = max_seq_length
|
| 441 |
+
self.max_batch_size = max_batch_size
|
| 442 |
+
self.EOS = EOS
|
| 443 |
+
assert self.EOS == self.vocab_size - 1
|
| 444 |
+
|
| 445 |
+
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
| 446 |
+
self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
|
| 447 |
+
self.h: TransformerDecoderABC
|
| 448 |
+
|
| 449 |
+
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
| 450 |
+
self.ar_text_position = SinePositionalEmbedding(
|
| 451 |
+
self.embedding_dim,
|
| 452 |
+
scale=False,
|
| 453 |
+
max_batch_size=max_batch_size,
|
| 454 |
+
max_seq_len=max_seq_length,
|
| 455 |
+
)
|
| 456 |
+
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
| 457 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
| 458 |
+
self.embedding_dim,
|
| 459 |
+
scale=False,
|
| 460 |
+
max_batch_size=max_batch_size,
|
| 461 |
+
max_seq_len=max_seq_length,
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
self.kv_class: KVCacheProtocol
|
| 465 |
+
|
| 466 |
+
def init_cache(self, bsz: int = 0, *args, **kwds) -> MutableSequence[KVCache | KVCacheQ]:
|
| 467 |
+
bsz = bsz or self.h.max_batch_size
|
| 468 |
+
assert bsz <= self.h.max_batch_size
|
| 469 |
+
seq_lens = self.h.max_seq_length
|
| 470 |
+
dtype = self.bert_proj.bias.dtype
|
| 471 |
+
cache: MutableSequence[KVCache | KVCacheQ] = [
|
| 472 |
+
self.kv_class.init_cache(bsz, seq_lens, self.n_head, self.head_dim, dtype, *args, **kwds)
|
| 473 |
+
for _ in range(self.n_layer)
|
| 474 |
+
]
|
| 475 |
+
mx.eval(cache)
|
| 476 |
+
return cache
|
| 477 |
+
|
| 478 |
+
def embed(
|
| 479 |
+
self,
|
| 480 |
+
x: list[Array],
|
| 481 |
+
y: Array,
|
| 482 |
+
bert_features: list[Array],
|
| 483 |
+
):
|
| 484 |
+
x_len: list[int] = [cast(tuple[int, ...], i.shape)[0] for i in x]
|
| 485 |
+
x_len_max = max(x_len)
|
| 486 |
+
xy_pos = mx.zeros((len(x), x_len_max + cast(tuple[int, ...], y.shape)[1], self.embedding_dim)).astype(
|
| 487 |
+
bert_features[0].dtype
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
bert_features = list(map(lambda x: x.swapaxes(0, 1), bert_features))
|
| 491 |
+
|
| 492 |
+
y_len = cast(tuple[int, ...], y.shape)[1]
|
| 493 |
+
y_emb = self.ar_audio_embedding(y)
|
| 494 |
+
y_pos = self.ar_audio_position.prefill(y_emb)
|
| 495 |
+
|
| 496 |
+
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
| 497 |
+
x_emb = self.ar_text_embedding(x_)
|
| 498 |
+
bert = self.bert_proj(bert_feature)
|
| 499 |
+
x_emb = x_emb + bert
|
| 500 |
+
x_pos = self.ar_text_position.prefill(mx.expand_dims(x_emb, 0))
|
| 501 |
+
xy_pos[[bs], :len_] = x_pos
|
| 502 |
+
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
| 503 |
+
|
| 504 |
+
mx.eval(xy_pos)
|
| 505 |
+
return xy_pos
|
| 506 |
+
|
| 507 |
+
def compile(self):
|
| 508 |
+
setattr(self.h, "__call__", mx.compile(self.h.__call__))
|
| 509 |
+
# setattr(self.h, "prefill", mx.compile(self.h.prefill, shapeless=True))
|
| 510 |
+
|
| 511 |
+
def pre_forward(self, session: T2SSessionMLX):
|
| 512 |
+
attn_mask = session.attn_mask
|
| 513 |
+
return list(), dict(attn_mask=attn_mask)
|
| 514 |
+
|
| 515 |
+
def post_forward(self, idx: int, session: T2SSessionMLX) -> None:
|
| 516 |
+
if idx == 0:
|
| 517 |
+
prefill_len = session.prefill_len
|
| 518 |
+
bsz = session.bsz
|
| 519 |
+
|
| 520 |
+
range_tensor = mx.arange(self.max_seq_length).reshape(1, 1, 1, self.max_seq_length)
|
| 521 |
+
prefill_len_expanded = prefill_len.reshape(bsz, 1, 1, 1)
|
| 522 |
+
attn_mask = range_tensor < prefill_len_expanded
|
| 523 |
+
attn_mask = mx.repeat(attn_mask, self.n_head, 1)
|
| 524 |
+
|
| 525 |
+
session.attn_mask = attn_mask
|
| 526 |
+
|
| 527 |
+
attn_mask = session.attn_mask
|
| 528 |
+
input_pos = session.input_pos
|
| 529 |
+
attn_mask[mx.arange(session.bsz), :, :, input_pos] = True
|
| 530 |
+
mx.eval(attn_mask)
|
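Note (illustration, not part of the commit): the decode path above looks up one sinusoidal row per sequence using a 1-based `input_pos`. A minimal sketch of that lookup, assuming only that `mlx` is installed:

import math

import mlx.core as mx

# Rebuild the same sinusoidal table SinePositionalEmbedding.compute_pe fills in.
max_seq_len, embed_dim = 16, 8
position = mx.expand_dims(mx.arange(max_seq_len), axis=1)
div_term = mx.exp(mx.arange(0, embed_dim, 2) * -(math.log(10000.0) / embed_dim))
pe = mx.zeros((max_seq_len, embed_dim))
pe[:, 0::2] = mx.sin(position * div_term)
pe[:, 1::2] = mx.cos(position * div_term)

# input_pos holds each sequence's 1-based length, so the newest token's
# positional row is pe[input_pos - 1], exactly as in __call__ above.
input_pos = mx.array([3, 5])
rows = pe[input_pos - 1]  # shape: (2, embed_dim)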
GPT_SoVITS/Accelerate/PyTorch/__init__.py
ADDED
@@ -0,0 +1,30 @@
import importlib.util

import torch

from .sample_funcs import sample_naive
from .structs import T2SRequest, T2SResult
from .t2s_engine import T2SEngine as T2SEngineTorch

torch.set_grad_enabled(False)

backends = ["torch_varlen"]
if torch.cuda.is_available():
    backends.append("torch_static_cuda_graph")
    # if importlib.util.find_spec("sageattention") is not None:
    #     for i in range(torch.cuda.device_count()):
    #         major, minor = torch.cuda.get_device_capability(i)
    #         sm_version = major + minor / 10.0
    #         if sm_version >= 7.0:
    #             backends.append("sage_attn_varlen_cuda_graph")
    if importlib.util.find_spec("flash_attn") is not None:
        for i in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(i)
            sm_version = major + minor / 10.0
            if sm_version >= 7.5:
                backends.append("flash_attn_varlen_cuda_graph")
# if torch.mps.is_available():
#     backends.append("mps_flash_attn_varlen")


__all__ = ["T2SEngineTorch", "T2SRequest", "sample_naive", "T2SResult", "backends"]
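Note (illustration, not part of the commit): `backends` is an ordered capability list, so a caller can pick the most specialised entry that survived the probes above. A minimal sketch of that pattern; the `preference` ranking is a hypothetical choice, and the import assumes the repo root is on `sys.path`:

from GPT_SoVITS.Accelerate.PyTorch import backends

# Prefer the most specialised backend that made it through the capability checks.
preference = ["flash_attn_varlen_cuda_graph", "torch_static_cuda_graph", "torch_varlen"]
chosen = next(b for b in preference if b in backends)
print(f"using backend: {chosen}")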
GPT_SoVITS/Accelerate/PyTorch/backends/flash_attn_varlen_cuda_graph.py
ADDED
@@ -0,0 +1,158 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from typing import Dict, List, Tuple

import kernels
import torch

from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheNHD,
    KVCacheProtocol,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

flash_attn_kernel = None
try:
    import flash_attn_interface as flash_attn  # type: ignore

    flash_attn_kernel = flash_attn.flash_attn_with_kvcache
except ModuleNotFoundError:
    try:
        import flash_attn  # type: ignore

        flash_attn_kernel = flash_attn.flash_attn_with_kvcache

    except ModuleNotFoundError:
        pass

if flash_attn_kernel is None:
    flash_attn_kernel = kernels.get_kernel("kernels-community/flash-attn").flash_attn_with_kvcache


Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor:
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        # Call the resolved kernel so the kernels-community fallback is also usable
        # when neither flash_attn_interface nor flash_attn could be imported.
        attn: Tensor = flash_attn_kernel(  # type: ignore
            q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
        )

        attn = attn.view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        assert torch.cuda.is_available()
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheNHD

    def post_forward(self, idx: int, session: T2SSession) -> None:
        return super().post_forward(idx, session)

    def pre_forward(self, session: T2SSession) -> Tuple[List, Dict]:
        return super().pre_forward(session)


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder: T2SDecoder,
    ) -> None:
        self.is_applicable = True
        super().__init__(decoder)

    def release_graph(self, session: T2SSession):
        if session.id == self.id:
            self.assigned = False
        else:
            del session.graph, session.xy_pos_, session.xy_dec_, session.input_pos, session.kv_cache

    def get_cache_graph(self, session: T2SSession):
        assert self.graph
        session.graph = self.graph
        session.stream = self.stream

        session.xy_pos_ = self.xy_pos
        session.xy_dec_ = self.xy_dec
        session.input_pos = self.input_pos.copy_(session.input_pos)

        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
            cache.sync_cache(cache_)

    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)

        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
        session.graph = graph
        session.stream = torch.cuda.Stream()  # type: ignore
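Note (illustration, not part of the commit): `flash_attn_with_kvcache` fuses the cache write and the attention over the valid prefix into one kernel. A plain-PyTorch sketch of that contract for a single decode token, useful as a mental reference only:

import torch
import torch.nn.functional as F

def sdpa_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens):
    # q, k, v: [B, 1, H, D]; caches: [B, S, H, D]; cache_seqlens: [B], 0-based write index.
    b = torch.arange(q.shape[0])
    k_cache[b, cache_seqlens] = k[:, 0]
    v_cache[b, cache_seqlens] = v[:, 0]
    # Attend only over slots up to and including the one just written.
    S = k_cache.shape[1]
    valid = torch.arange(S)[None, :] <= cache_seqlens[:, None]  # [B, S]
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2),
        k_cache.transpose(1, 2),
        v_cache.transpose(1, 2),
        attn_mask=valid[:, None, None, :],
    )
    return out.transpose(1, 2)  # [B, 1, H, D]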
GPT_SoVITS/Accelerate/PyTorch/backends/mps_flash_attn_varlen.py
ADDED
@@ -0,0 +1,166 @@
import torch
from torch.nn import functional as F

from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheHND,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k, v = kv_cache.update(input_pos, k, v)

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheHND

    def pre_forward(self, session: T2SSession):
        attn_mask = session.attn_mask
        return list(), dict(attn_mask=attn_mask)

    def post_forward(self, idx: int, session: T2SSession) -> None:
        if idx == 0:
            prefill_len = session.prefill_len
            bsz = session.bsz

            range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
            prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
            attn_mask = range_tensor < prefill_len_expanded
            attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)

            session.attn_mask = attn_mask

        attn_mask = session.attn_mask
        input_pos = session.input_pos
        attn_mask[torch.arange(session.bsz), :, :, input_pos] = True


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder,
    ) -> None:
        self.is_applicable = False
        super().__init__(decoder)
        if torch.cuda.is_available():
            self.attn_mask = (
                torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
                .bool()
                .to(self.device, self.dtype)
            )

    def release_graph(self, session: T2SSession):
        if session.id == self.id:
            self.assigned = False
        else:
            del (
                session.graph,
                session.xy_pos_,
                session.xy_dec_,
                session.input_pos,
                session.kv_cache,
                session.attn_mask,
            )

    def get_cache_graph(self, session: T2SSession):
        assert self.graph
        session.graph = self.graph
        session.stream = self.stream

        session.xy_pos_ = self.xy_pos
        session.xy_dec_ = self.xy_dec
        session.input_pos = self.input_pos.copy_(session.input_pos)

        session.attn_mask = self.attn_mask

        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
            cache.sync_cache(cache_)

    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)

        session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)

        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
        session.graph = graph
        session.stream = torch.cuda.Stream()  # type: ignore
GPT_SoVITS/Accelerate/PyTorch/backends/sage_attn_varlen_cuda_graph.py
ADDED
@@ -0,0 +1,175 @@
import sageattention  # type: ignore
import torch

from .. import nn
from ..structs import T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheHND,
    KVCacheProtocol,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(
        self,
        x: Tensor,
        input_pos: Tensor,
        kv_cache: KVCacheProtocol,
        cu_seqlens_q: Tensor,
        cu_seqlens_kv: Tensor,
    ) -> Tensor:
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k, v = kv_cache.update(input_pos, k, v)

        attn: Tensor = sageattention.sageattn_varlen(
            q,
            k,
            v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_kv,
            max_seqlen_q=1,
            max_seqlen_k=self.max_seq_length,
        )

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head, ffn_dim, hidden_dim, max_seq_length) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheHND

    def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
        return list(), dict(cu_seqlens_q=session.cu_seqlens_q, cu_seqlens_kv=session.cu_seqlens_kv)

    def post_forward(self, idx: int, session: T2SSession):
        if idx == 0:
            session.cu_seqlens_q = torch.arange(0, session.bsz + 1, dtype=torch.int32)
            # NOTE: torch.cat requires 1-D tensors here; a 0-dim torch.tensor(0) would raise.
            session.cu_seqlens_kv = torch.cat([torch.tensor([0], dtype=torch.int32), session.input_pos])
        else:
            cu_seqlens_q = session.cu_seqlens_q
            cu_seqlens_kv = session.cu_seqlens_kv
            cu_seqlens_kv.add_(cu_seqlens_q)


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder: T2SDecoder,
    ) -> None:
        self.is_applicable = False
        super().__init__(decoder)

        if torch.cuda.is_available():
            self.cu_seqlens_q = torch.arange(0, decoder.max_batch_size + 1, dtype=torch.int32).to(self.device)
            self.cu_seqlens_kv = torch.cat([torch.tensor([0], dtype=torch.int32), self.input_pos]).to(self.device)

    def release_graph(self, session: T2SSession):
        if session.id == self.id:
            self.assigned = False
        else:
            del (
                session.graph,
                session.xy_pos_,
                session.xy_dec_,
                session.input_pos,
                session.kv_cache,
                session.cu_seqlens_q,
                session.cu_seqlens_kv,
            )

    def get_cache_graph(self, session: T2SSession):
        assert self.graph
        session.graph = self.graph
        session.stream = self.stream

        session.xy_pos_ = self.xy_pos
        session.xy_dec_ = self.xy_dec
        session.input_pos = self.input_pos.copy_(session.input_pos)

        session.cu_seqlens_q = self.cu_seqlens_q
        session.cu_seqlens_kv = self.cu_seqlens_kv

        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
            cache.sync_cache(cache_)

    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)

        session.cu_seqlens_q = self.cu_seqlens_q.clone().copy_(session.cu_seqlens_q)
        session.cu_seqlens_kv = self.cu_seqlens_kv.clone().copy_(session.cu_seqlens_kv)

        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
        session.graph = graph
        session.stream = torch.cuda.Stream()  # type: ignore
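Note (illustration, not part of the commit): `sageattn_varlen` consumes `cu_seqlens` arrays, i.e. cumulative token offsets marking where each packed sequence starts. A standalone example of that layout with made-up lengths; it shows the format varlen kernels expect, not the incremental update rule `post_forward` applies between steps:

import torch

bsz = 3
# One query token per sequence per decode step, so offsets are simply 0..bsz.
cu_seqlens_q = torch.arange(0, bsz + 1, dtype=torch.int32)  # tensor([0, 1, 2, 3])

# Key/value offsets accumulate each sequence's current length.
input_pos = torch.tensor([5, 7, 4], dtype=torch.int32)  # hypothetical lengths
cu_seqlens_kv = torch.cat([torch.zeros(1, dtype=torch.int32), input_pos.cumsum(0).to(torch.int32)])
print(cu_seqlens_kv)  # tensor([ 0,  5, 12, 16], dtype=torch.int32)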
GPT_SoVITS/Accelerate/PyTorch/backends/torch_static_cuda_graph.py
ADDED
@@ -0,0 +1,166 @@
import torch
from torch.nn import functional as F

from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheHND,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k, v = kv_cache.update(input_pos, k, v)

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheHND

    def pre_forward(self, session: T2SSession):
        attn_mask = session.attn_mask
        return list(), dict(attn_mask=attn_mask)

    def post_forward(self, idx: int, session: T2SSession) -> None:
        if idx == 0:
            prefill_len = session.prefill_len
            bsz = session.bsz

            range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
            prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
            attn_mask = range_tensor < prefill_len_expanded
            attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)

            session.attn_mask = attn_mask

        attn_mask = session.attn_mask
        input_pos = session.input_pos
        attn_mask[torch.arange(session.bsz), :, :, input_pos] = True


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder,
    ) -> None:
        self.is_applicable = True
        super().__init__(decoder)
        if torch.cuda.is_available():
            self.attn_mask = (
                torch.randint(0, 2, (decoder.max_batch_size, decoder.n_head, 1, decoder.max_seq_length))
                .bool()
                .to(self.device, self.dtype)
            )

    def release_graph(self, session: T2SSession):
        if session.id == self.id:
            self.assigned = False
        else:
            del (
                session.graph,
                session.xy_pos_,
                session.xy_dec_,
                session.input_pos,
                session.kv_cache,
                session.attn_mask,
            )

    def get_cache_graph(self, session: T2SSession):
        assert self.graph
        session.graph = self.graph
        session.stream = self.stream

        session.xy_pos_ = self.xy_pos
        session.xy_dec_ = self.xy_dec
        session.input_pos = self.input_pos.copy_(session.input_pos)

        session.attn_mask = self.attn_mask

        for cache, cache_ in zip(self.kv_cache, session.kv_cache):
            cache.sync_cache(cache_)

    def capture_new_graph(self, session: T2SSession):
        session.xy_pos_ = self.xy_pos.clone()
        session.xy_dec_ = self.xy_dec.clone()
        session.input_pos = self.input_pos.clone().copy_(session.input_pos)

        session.attn_mask = self.attn_mask.clone().copy_(session.attn_mask)

        args, kwds = self.decoder.pre_forward(session)
        graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
        session.graph = graph
        session.stream = torch.cuda.Stream()  # type: ignore
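Note (illustration, not part of the commit): the `post_forward` hooks in these backends all build the same boolean decode mask: prompt positions are visible, and each step opens the slot the new token occupies. A standalone sketch of that logic, with `input_pos` taken here as the 0-based slot of the token being decoded:

import torch

max_seq_length, n_head, bsz = 8, 2, 2
prefill_len = torch.tensor([3, 5])

# True wherever a cache position belongs to the prefilled prompt.
range_tensor = torch.arange(max_seq_length).view(1, 1, 1, max_seq_length)
attn_mask = range_tensor < prefill_len.view(bsz, 1, 1, 1)  # [bsz, 1, 1, S]
attn_mask = attn_mask.expand(-1, n_head, -1, -1).clone()   # clone to avoid writing through a view

input_pos = prefill_len  # next write slot per sequence
attn_mask[torch.arange(bsz), :, :, input_pos] = True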
GPT_SoVITS/Accelerate/PyTorch/backends/torch_varlen.py
ADDED
@@ -0,0 +1,145 @@
from typing import NoReturn

import torch
from torch.nn import functional as F

from .. import nn
from ..structs import KVCacheProtocol, T2SSession
from ..t2s_model_abc import (
    AttentionABC,
    CUDAGraphCacheABC,
    FeedForward,
    KVCacheHNDVarlen,
    T2SDecoderABC,
    TransformerBlockABC,
    TransformerDecoderABC,
)

Tensor = torch.Tensor


class Attention(AttentionABC):
    def __init__(self, n_head, hidden_dim, max_seq_length):
        super().__init__(n_head, hidden_dim, max_seq_length)

        # key, query, value projections for all heads, but in a batch
        self.in_proj = nn.Linear(hidden_dim, hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor):
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_head, self.head_dim)
        v = v.view(bsz, seqlen, self.n_head, self.head_dim)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        k, v = kv_cache.update(input_pos, k, v)

        max_idx = input_pos.max()

        q, k, v = map(lambda x: x[..., :max_idx, :], (q, k, v))

        mask = attn_mask[..., :max_idx]

        attn = F.scaled_dot_product_attention(q, k, v, mask)

        attn = attn.transpose(1, 2).contiguous().view(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlock(TransformerBlockABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention = Attention(n_head, hidden_dim, max_seq_length)
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm([self.hidden_dim])
        self.ffn_norm = nn.LayerNorm([self.hidden_dim])


class TransformerDecoder(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim,
        n_layer,
        n_head,
        ffn_dim,
        vocab_size,
        max_seq_length,
        max_batch_size,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers = nn.ModuleList(  # type: ignore
            TransformerBlock(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )


class T2SDecoder(T2SDecoderABC):
    def __init__(
        self,
        config,
        max_seq_length=2000,
        max_batch_size=10,
    ) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)
        self.h: TransformerDecoderABC = TransformerDecoder(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

        self.kv_class = KVCacheHNDVarlen

    def capture(
        self,
        *args,
        **kwds,
    ) -> NoReturn:
        raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

    def pre_forward(self, session: T2SSession):
        attn_mask = session.attn_mask
        return list(), dict(attn_mask=attn_mask)

    def post_forward(self, idx: int, session: T2SSession) -> None:
        if idx == 0:
            prefill_len = session.prefill_len
            bsz = session.bsz

            range_tensor = torch.arange(self.max_seq_length).view(1, 1, 1, self.max_seq_length)
            prefill_len_expanded = prefill_len.view(bsz, 1, 1, 1)
            attn_mask = range_tensor < prefill_len_expanded
            attn_mask = attn_mask.expand(-1, self.n_head, -1, -1)

            session.attn_mask = attn_mask

        attn_mask = session.attn_mask
        input_pos = session.input_pos
        attn_mask[torch.arange(session.bsz), :, :, input_pos] = True


class CUDAGraphCache(CUDAGraphCacheABC):
    def __init__(
        self,
        decoder,
    ) -> None:
        self.is_applicable = False
        super().__init__(decoder)

    def release_graph(self, session: T2SSession):
        raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

    def get_cache_graph(self, session: T2SSession):
        raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")

    def capture_new_graph(self, session: T2SSession):
        raise NotImplementedError("Cuda Graph Is Not Supported For Varlen Model")
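Note (illustration, not part of the commit): the varlen backend's gain comes from the slicing in `Attention.__call__` above: it attends over only the longest active prefix instead of the full cache. A shape-level sketch of the saving, assuming a 2000-slot cache:

import torch

B, H, S, D = 2, 16, 2000, 64
k_cache = torch.zeros(B, H, S, D)
input_pos = torch.tensor([120, 87])

# Keep only the slots any sequence in the batch could occupy this step.
max_idx = int(input_pos.max())
k_active = k_cache[..., :max_idx, :]  # [2, 16, 120, 64] instead of [2, 16, 2000, 64]
print(f"{S / max_idx:.1f}x fewer key positions per step")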
GPT_SoVITS/Accelerate/PyTorch/export.py
ADDED
|
@@ -0,0 +1,467 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import enum
import os
import os.path as osp
import time
from pathlib import Path
from typing import MutableSequence, TypeAlias

import torch
import typer
from torch.export import Dim
from torch.nn import functional as F

from ..logger import logger
from . import nn
from .t2s_model_abc import AttentionABC, FeedForward, T2SDecoderABC, TransformerBlockABC, TransformerDecoderABC

Tensor = torch.Tensor

KVCache: TypeAlias = tuple[Tensor, Tensor]

app = typer.Typer(
    context_settings={"help_option_names": ["-h", "--help"]},
    add_completion=False,
)


class Stage(str, enum.Enum):
    embed = "embed"
    decode = "decode"


class KVCacheONNX:
    @staticmethod
    def empty(kv_cache):
        assert len(kv_cache) == 2
        k_cache, v_cache = kv_cache

        k_cache[:] = 0
        v_cache[:] = 0

    @staticmethod
    def update_cache(
        input_pos: Tensor, k_val: Tensor, v_val: Tensor, kv_cache: tuple[Tensor, Tensor], cache_idx: Tensor
    ):
        # input_pos: [B, ], k_val: [B, H, 1, D]
        k_out, v_out = kv_cache
        ip0 = input_pos - 1

        k_out[cache_idx, :, ip0, None] = k_val
        v_out[cache_idx, :, ip0, None] = v_val

        return k_out, v_out

    @staticmethod
    def prefill_kv(k_val: Tensor, v_val: Tensor, kv_cache: tuple[Tensor, Tensor]):
        # k_val: [B, S, H, D]
        k_cache, v_cache = kv_cache

        k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
        v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)

    @staticmethod
    def init_cache(batch_size: int, max_seq_length: int, n_heads: int, head_dim: int, dtype: torch.dtype):
        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)

        return (torch.zeros(cache_shape, dtype=dtype), torch.zeros(cache_shape, dtype=dtype))


class AttentionONNX(AttentionABC):
    def __init__(self, n_heads: int, head_dim: int, max_seq_length: int):
        super().__init__(n_heads, head_dim, max_seq_length)

        self.in_proj = nn.Linear(self.hidden_dim, self.hidden_dim * 3, bias=True)
        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True)

    def __call__(self, *args, **kwds):  # type: ignore
        pass

    def onnx_prefill(self, x: Tensor, kv_cache: KVCache, attn_mask: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape

        torch._check(attn_mask.size(-2) == x.size(-2))

        q, k, v = self.in_proj(x.unsqueeze(0)).chunk(3, dim=-1)

        q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))

        KVCacheONNX.prefill_kv(k, v, kv_cache)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)

        output = self.out_proj(attn)

        return output

    def onnx_decode(self, x: Tensor, input_pos: Tensor, kv_cache: KVCache, cache_idx: Tensor, attn_mask: Tensor):
        bsz, seqlen, _ = x.shape

        torch._check(attn_mask.size(-2) == 1)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q, k, v = map(lambda x: x.reshape(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))

        q, k, v = map(lambda x: x.swapaxes(1, 2), (q, k, v))

        kv_cache = KVCacheONNX.update_cache(input_pos, k, v, kv_cache, cache_idx)

        max_idx = int(input_pos.max())

        q, k, v = map(lambda x: x[..., :max_idx, :], (q, *kv_cache))

        mask = attn_mask[..., :max_idx]

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

        attn = attn.swapaxes(1, 2).reshape(bsz, seqlen, self.hidden_dim)

        attn = self.out_proj(attn)

        return attn


class TransformerBlockONNX(TransformerBlockABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__(n_head, ffn_dim, hidden_dim, max_seq_length)

        self.attention: AttentionONNX = AttentionONNX(n_head, hidden_dim, max_seq_length)  # type: ignore
        self.feed_forward = FeedForward(hidden_dim, ffn_dim)
        self.attention_norm = nn.LayerNorm(self.hidden_dim)
        self.ffn_norm = nn.LayerNorm(self.hidden_dim)

    def onnx_prefill(self, x: Tensor, attn_mask: Tensor, kv_cache: KVCache):
        h = self.attention_norm(
            x
            + self.attention.onnx_prefill(
                x,
                kv_cache,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))

        return out

    def onnx_decode(self, x: Tensor, input_pos: Tensor, kv_cache: KVCache, cache_idx: Tensor, attn_mask: Tensor):
        h = self.attention_norm(
            x
            + self.attention.onnx_decode(
                x,
                input_pos,
                kv_cache,
                cache_idx,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out


class TransformerDecoderONNX(TransformerDecoderABC):
    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
    ) -> None:
        super().__init__(hidden_dim, n_layer, n_head, ffn_dim, vocab_size, max_seq_length, max_batch_size)

        self.layers: MutableSequence[TransformerBlockONNX] = nn.ModuleList(  # type: ignore
            TransformerBlockONNX(n_head, ffn_dim, hidden_dim, max_seq_length) for _ in range(n_layer)
        )

    def onnx_prefill(self, x: Tensor, mask: Tensor, *kv_caches: KVCache):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.onnx_prefill(
                x,
                mask,
                kv_cache,
            )
        return x

    def onnx_decode(
        self,
        input_pos: Tensor,
        x: Tensor,
        cache_idx: Tensor,
        attn_mask: Tensor,
        *kv_caches: KVCache,
    ):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.onnx_decode(
                x,
                input_pos,
                kv_cache,
                cache_idx,
                attn_mask,
            )

        return x


class T2SDecoderONNX(T2SDecoderABC):
    def __init__(self, config: dict, max_seq_length: int = 2000, max_batch_size: int = 10) -> None:
        super().__init__(config, max_seq_length, max_batch_size)

        self.bert_proj = nn.Linear(1024, self.embedding_dim)
        self.ar_predict_layer = nn.Linear(self.hidden_dim, self.vocab_size, bias=False)

        self.h = TransformerDecoderONNX(
            self.hidden_dim, self.n_layer, self.n_head, self.ffn_dim, self.vocab_size, max_seq_length, max_batch_size
        )

    def pre_forward(self, session) -> tuple[list[Tensor], dict[str, Tensor]]:
        return super().pre_forward(session)

    def post_forward(self, idx: int, session) -> None:
        return super().post_forward(idx, session)

    def embed_onnx_(
        self,
        x: Tensor,
        x_len: Tensor,
        y: torch.Tensor,
        bert_features: Tensor,
    ):
        B = x.shape[0]
        D = self.embedding_dim
        T_TOTAL = 500
        xy_pos = torch.zeros((B, T_TOTAL, D)).to(bert_features[0].dtype)

        bert_features = bert_features.transpose(1, 2)

        y_len = y.shape[1]
        y_emb = self.ar_audio_embedding(y)
        y_pos = self.ar_audio_position.prefill(y_emb)

        for bs, x_, len_, bert_feature in zip(torch.arange(x.shape[0]), x, x_len, bert_features):
            x_emb = self.ar_text_embedding(x_[:len_])

            bert = self.bert_proj(bert_feature[:len_])

            print(bert.shape, bert_feature[:len_])

            # debug early-return: the remainder of this loop body is currently unreachable
            return bert, bert_feature[:len_].unsqueeze(0)

            return bert[:20].unsqueeze(0), None
            x_emb = x_emb + bert
            x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))

            xy_pos[None, bs, :len_] = bert
            # xy_pos[None, bs, len_ : len_ + y_len] = y_pos

            return xy_pos[:, -1], None

        return xy_pos[: x.shape[0]], x_len

    def embed_onnx(
        self,
        x: torch.Tensor,  # [B, Tx]
        x_len: torch.Tensor,  # [B]
        y: torch.Tensor,  # [1, Ty, D]
        bert_features: torch.Tensor,  # [B, 1024, Tx]
    ):
        # [B, 1024, Tx] -> [B, Tx, 1024]
        bert_features = bert_features.transpose(1, 2)

        Ty = y.shape[1]
        Tx = x.shape[1]
        B = x.shape[0]
        D = self.embedding_dim
        T_TOTAL = 500

        # mask: [B, Tx], column j valid where j < x_len[i]
        col = torch.arange(Tx, device=x.device).unsqueeze(0)  # [1, Tx]
        mask_x = col < x_len.view(-1, 1)  # [B, Tx]
        mask_x3 = mask_x.unsqueeze(-1)  # [B, Tx, 1]

        torch._check((Ty >= 0) and (Ty <= 250), "y_len out of range")
        torch._check((Tx >= 0) and (Tx <= 250), "x_len out of range")

        y_emb = self.ar_audio_embedding(y)  # [1, Ty, D]
        y_pos = self.ar_audio_position.prefill(y_emb)  # [1, Ty, D]

        x_emb_full = self.ar_text_embedding(x)  # [B, Tx, D]
        bert_full = self.bert_proj(bert_features[[0], : x_len[0]])  # [B, Tx, D]

        print(bert_full[0].shape, bert_features[0, : x_len[0]])

        # debug early-return: the placement logic below is currently skipped
        return bert_full[0], bert_features[0, : x_len[0]]

        x_sum_full = x_emb_full + bert_full  # [B, Tx, D]
        x_pos_full = self.ar_text_position.prefill(x_sum_full)  # [B, Tx, D]

        xy_pos = torch.zeros((B, T_TOTAL, D), dtype=x_pos_full.dtype, device=x_pos_full.device)

        xy_pos[:, :Tx, :] = torch.where(
            mask_x3,
            bert_full[:, :Tx, :].to(xy_pos.dtype),
            xy_pos[:, :Tx, :],
        )

        return xy_pos[:, -1], None

        # Start from offset=x_len, write Ty positions
        # [Ty] index: offsets + [0..Ty-1]
        offsets = x_len.clamp(min=0, max=T_TOTAL - Ty)  # [B]
        idx_y = offsets.unsqueeze(1) + torch.arange(Ty, device=x_pos_full.device)  # [B, Ty]
        # scatter along dim=1; expand index to [B, Ty, D]
        idx_y3 = idx_y.unsqueeze(-1).expand(B, Ty, D)
        y_pos_b = y_pos.expand(B, Ty, D).to(xy_pos.dtype)  # [B, Ty, D]
        xy_pos = xy_pos.scatter(1, idx_y3, y_pos_b)

        return xy_pos, x_len


def torchscript_export(model: T2SDecoderONNX, stage="embed"):
    if stage == "embed":
        x = torch.randint(1, 600, (model.max_batch_size, 50))
        x_len = torch.randint(30, 50, (model.max_batch_size,))
        y = torch.randint(1, 600, (1, 200))
        bert_features = torch.rand((model.max_batch_size, 1024, 50))

        x_len[-1] = 50

        mask = torch.arange(x_len.max().item(), device=x.device).unsqueeze(0) < x_len.unsqueeze(1)

        x = x * mask
        bert_features = bert_features * mask.unsqueeze(1)

        try:
            a, c = model.embed_onnx_(x, x_len, y, bert_features)
            b, d = model.embed_onnx(x, x_len, y, bert_features)
            print("-" * 20)
            print(a - b, (a - b).sum(), (a - b).square().mean())
            print(c - d, (c - d).sum(), (c - d).square().mean())
            # debug: compare the two embed implementations, then bail out before the actual export runs
            exit()
            assert torch.allclose(a, b, atol=1e-6, rtol=1e-8), (a - b).square().mean()

            setattr(model, "forward", model.embed_onnx)
            scripted_model = torch.jit.script(model, example_inputs=[(x, x_len, y, bert_features)])

            onnx_program = torch.onnx.export(
                scripted_model,
                (x, x_len, y, bert_features),
                input_names=["text", "text_len", "prompt", "bert_features"],
                output_names=["xy_pos", "input_pos"],
                dynamic_axes={
                    "text": {0: "Batch_Size", 1: "Sequence_Length_X"},
                    "prompt": {0: "Batch_Size", 1: "Sequence_Length_Y"},
                    "bert_features": {0: "Batch_Size", 1: "Sequence_Length_X"},
                },
                opset_version=21,
                training=False,
                do_constant_folding=True,
                external_data=False,
            )
            assert onnx_program
            onnx_program.save("onnx_export/AR_Embedding_TorchScript.onnx")

        except Exception:
            logger.bind(show_locals=False).exception("")


def dynamo_export(model: T2SDecoderONNX, stage="embed"):
    if stage == "embed":
        x = torch.randint(1, 600, (model.max_batch_size, 50))
        x_len = torch.randint(30, 50, (model.max_batch_size,))
        y = torch.randint(1, 600, (1, 200))
        bert_features = torch.rand((model.max_batch_size, 1024, 50))

        x_len[-1] = 50

        mask = torch.arange(x_len.max().item(), device=x.device).unsqueeze(0) < x_len.unsqueeze(1)

        x = x * mask
        bert_features = (bert_features.transpose(1, 2) * mask.unsqueeze(-1)).transpose(1, 2)

        dynamic_shapes = [
            {
                0: Dim("Batch_Size", min=1, max=4),
                1: Dim("Sequence_Length_X", min=1, max=50),
            },
            {
                0: Dim("Batch_Size", min=1, max=4),
            },
            {
                1: Dim("Sequence_Length_Y", min=1, max=250),
            },
            {
                0: Dim("Batch_Size", min=1, max=4),
                2: Dim("Sequence_Length_X", min=1, max=50),
            },
        ]
        try:
            a = model.embed_onnx_(x, x_len, y, bert_features)[0]
            b = model.embed_onnx(x, x_len, y, bert_features)[0]
            print(a - b, (a - b).square().mean())
            # debug: compare the two embed implementations, then bail out before the actual export runs
            exit()
            assert torch.allclose(a, b, atol=1e-6, rtol=1e-8), (a - b).square().mean()

            setattr(model, "forward", model.embed_onnx)
            onnx_program = torch.onnx.export(
                model,
                (x, x_len, y, bert_features),
                input_names=["text", "text_len", "prompt", "bert_features"],
                output_names=["xy_pos", "input_pos"],
                dynamo=True,
                dynamic_shapes=dynamic_shapes,
                opset_version=21,
                training=False,
                do_constant_folding=True,
                external_data=False,
            )
            assert onnx_program
            onnx_program.save("onnx_export/AR_Embedding_Dynamo.onnx")
        except Exception:
            logger.bind(show_locals=False).exception("")


@app.command()
def export(
    ckpt_path: Path = typer.Option(
        ...,
        "--ckpt-path",
        file_okay=True,
        dir_okay=False,
        exists=True,
        readable=True,
        show_default=False,
        help="AR Checkpoint",
    ),
    dynamo: bool = typer.Option(False, is_flag=True, flag_value=True, help="Use Torch Dynamo"),
    stages: list[Stage] = typer.Option([Stage.embed], "--stages", help="Stage to export"),
):
    os.makedirs("onnx_export", exist_ok=True)
    dict_s1 = torch.load(ckpt_path, "cpu", mmap=True)
    config = dict_s1["config"]
    model = T2SDecoderONNX(config, 2000, 4)
    state_dict = dict_s1["weight"]
    model.load_state_dict(state_dict)

    for stage in stages:
        if dynamo:
            dynamo_export(model, stage)
        else:
            torchscript_export(model, stage)


def get_prog_name() -> str:
    script_rel = ".".join(["GPT_SoVITS", "Accelerate", "PyTorch", osp.basename(__file__).removesuffix(".py")])
    return f"python -s -m {script_rel}"


if __name__ == "__main__":
    t = time.perf_counter()
    app(prog_name=get_prog_name())
    logger.info(f"Exec Time: {time.perf_counter() - t:.2f} secs")
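
A note on embed_onnx above: the prompt embeddings are placed at a different offset per sample with a single scatter along dim=1, avoiding data-dependent Python slicing that ONNX export cannot trace. A minimal standalone sketch of that trick with toy sizes (all numbers below are illustrative, not taken from any checkpoint config):

import torch

B, Ty, D, T_TOTAL = 3, 4, 2, 10  # toy sizes
x_len = torch.tensor([2, 5, 3])  # per-sample text lengths, i.e. write offsets

y_pos = torch.arange(Ty * D, dtype=torch.float32).view(1, Ty, D)  # shared prompt embedding [1, Ty, D]
xy_pos = torch.zeros(B, T_TOTAL, D)

offsets = x_len.clamp(min=0, max=T_TOTAL - Ty)   # [B]
idx_y = offsets.unsqueeze(1) + torch.arange(Ty)  # [B, Ty] target positions per sample
idx_y3 = idx_y.unsqueeze(-1).expand(B, Ty, D)    # expand the index to [B, Ty, D] for scatter
xy_pos = xy_pos.scatter(1, idx_y3, y_pos.expand(B, Ty, D))

# each row b now holds y_pos at positions [x_len[b], x_len[b] + Ty)
assert torch.equal(xy_pos[0, 2:6], y_pos[0])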
GPT_SoVITS/Accelerate/PyTorch/nn.py
ADDED
@@ -0,0 +1,69 @@
"""
Enhanced Type Hint nn.Module
Modified From https://github.com/labmlai/labml/blob/master/helpers/labml_helpers/module.py
"""

from typing import Any

import torch.nn
from torch.nn import (
    functional as functional,
)
from torch.nn import (
    utils as utils,
)
from torch.nn.modules import *  # type: ignore  # noqa: F403
from torch.nn.parameter import (
    Parameter as Parameter,
)

Tensor = torch.Tensor


class Module(torch.nn.Module):
    r"""
    Wraps ``torch.nn.Module`` to overload ``__call__`` instead of
    ``forward`` for better type checking.

    `PyTorch Github issue for clarification <https://github.com/pytorch/pytorch/issues/44605>`_
    """

    def _forward_unimplemented(self, *input: Any) -> None:
        # To stop PyTorch from giving abstract methods warning
        pass

    def __init_subclass__(cls, **kwargs):
        if cls.__dict__.get("__call__", None) is None:
            return

        setattr(cls, "forward", cls.__dict__["__call__"])
        delattr(cls, "__call__")

    @property
    def device(self) -> torch.device:
        params = self.parameters()
        try:
            sample_param = next(params)
            return sample_param.device
        except StopIteration:
            raise RuntimeError(f"Unable to determine device of {self.__class__.__name__}") from None


class Linear(torch.nn.Linear):
    def __call__(self, input: Tensor) -> Tensor:
        return super().__call__(input)


class Dropout(torch.nn.Dropout):
    def __call__(self, input: Tensor) -> Tensor:
        return super().__call__(input)


class Embedding(torch.nn.Embedding):
    def __call__(self, input: Tensor) -> Tensor:
        return super().__call__(input)


class LayerNorm(torch.nn.LayerNorm):
    def __call__(self, input: Tensor) -> Tensor:
        return super().__call__(input)
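
The __init_subclass__ hook is the core of this module: a subclass writes a precisely typed __call__, and the hook silently moves it onto forward, so torch.nn.Module's hook-aware dispatch still runs while type checkers see the exact signature. A self-contained sketch of the same pattern (Scale is a made-up example class, not part of the repo):

import torch


class Module(torch.nn.Module):
    def __init_subclass__(cls, **kwargs):
        # move the subclass's __call__ onto forward, keeping nn.Module's hook-aware __call__
        if cls.__dict__.get("__call__") is not None:
            setattr(cls, "forward", cls.__dict__["__call__"])
            delattr(cls, "__call__")


class Scale(Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def __call__(self, x: torch.Tensor) -> torch.Tensor:  # becomes forward() at class creation
        return x * self.factor


m = Scale(2.0)
print(m(torch.ones(2)))  # tensor([2., 2.]) — dispatched via nn.Module.__call__ -> forward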
GPT_SoVITS/Accelerate/PyTorch/sample_funcs.py
ADDED
@@ -0,0 +1,67 @@
from typing import Protocol

import torch
import torch.nn.functional as F

Tensor = torch.Tensor


class SampleProtocol(Protocol):
    @staticmethod
    def __call__(
        logits: Tensor,
        previous_tokens: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ) -> Tensor: ...


class sample_naive(SampleProtocol):
    @staticmethod
    def __call__(
        logits: Tensor,
        previous_tokens: Tensor,
        temperature: float,
        top_k: int,
        top_p: float,
        repetition_penalty: float,
    ):
        if temperature <= 1e-5:
            probs = F.softmax(logits, dim=-1)
            return torch.argmax(probs, dim=-1, keepdim=True).to(dtype=torch.int32)

        if repetition_penalty != 1.0:
            previous_tokens = previous_tokens.long()
            score = torch.gather(logits, dim=1, index=previous_tokens)
            score = torch.where(
                score < 0,
                score * repetition_penalty,
                score / repetition_penalty,
            )
            logits.scatter_(dim=1, index=previous_tokens, src=score)

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            cum_probs[cum_probs > 1] = 1
            sorted_indices_to_remove = cum_probs > top_p
            sorted_indices_to_remove[:, 0] = False  # keep at least one option
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, -float("Inf"))

        if temperature < 1.0:
            logits /= temperature

        v, _ = torch.topk(logits, top_k)
        pivot = v[:, -1].unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)

        probs = F.softmax(logits, dim=-1)
        # exponential-race sampling: with q ~ Exp(1) elementwise, argmax(probs / q)
        # draws index i with probability probs[i], using only an argmax
        q = -torch.log(torch.rand_like(probs))
        idx_next = torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.int32)

        return idx_next
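
The last three lines of sample_naive are an exponential-race sampler: with q drawn elementwise from Exp(1), argmax(probs / q) is distributed exactly like a categorical draw from probs, so sampling stays a single argmax with no torch.multinomial call. A quick empirical check with toy logits (all values illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 8)  # toy logits over 8 tokens
probs = F.softmax(logits, dim=-1)

# empirical frequencies of argmax(probs / Exp(1)) should approach probs
counts = torch.zeros(8)
for _ in range(20000):
    q = -torch.log(torch.rand_like(probs))  # Exp(1) via inverse CDF
    counts[int(torch.argmax(probs / q, dim=-1))] += 1

print(torch.allclose(counts / counts.sum(), probs[0], atol=2e-2))  # True (statistically)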
GPT_SoVITS/Accelerate/PyTorch/structs.py
ADDED
@@ -0,0 +1,151 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Literal, MutableSequence, Optional, Protocol

import torch

from .sample_funcs import SampleProtocol, sample_naive

Tensor = torch.Tensor


@dataclass
class T2SResult:
    result: list[Tensor] | None = None
    infer_speed: tuple[float, float] = (0.0, 0.0)
    status: Literal["Success", "Error"] = "Success"
    exception: Optional[Exception] = None
    traceback: Optional[str] = None


@dataclass
class T2SRequest:
    x: list[torch.Tensor]
    x_lens: Tensor
    prompts: torch.Tensor
    bert_feature: list[Tensor]
    valid_length: int
    top_k: int = 5
    top_p: float = 1
    early_stop_num: int = -1
    temperature: float = 1.0
    repetition_penalty: float = 1.35
    use_cuda_graph: bool = False
    debug: bool = False


class KVCacheProtocol(Protocol):
    k_cache: Tensor
    v_cache: Tensor

    def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None: ...

    def empty(self) -> None: ...

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...

    def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...

    def sync_cache(self, kv_cache: KVCacheProtocol) -> None: ...


class T2SDecoderProtocol(Protocol):
    max_seq_length: int
    EOS: int
    n_head: int

    @property
    def device(self) -> torch.device: ...

    def embed(self, x: list[Tensor], y: Tensor, bert_features: list[Tensor]) -> Tensor: ...


class T2SEngineProtocol(Protocol):
    def _handle_request(self, request: T2SRequest) -> tuple[list[Tensor], float, float]: ...

    def generate(self, request: T2SRequest) -> T2SResult: ...


class T2SSession:
    def __init__(
        self,
        decoder: T2SDecoderProtocol,
        request: T2SRequest,
        sample_func: type[SampleProtocol] = sample_naive,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ):
        with device:
            self.decoder = decoder
            self.request = request
            self.device = device
            self.dtype = dtype

            bsz = len(request.x)
            y_len = request.prompts.size(-1)
            self.bsz = bsz
            self.y_len = y_len
            request.prompts = request.prompts.to(device, torch.int32)

            # Cache
            self.kv_cache: MutableSequence[KVCacheProtocol]
            self.sample = sample_func()

            # Forward args
            self.x = [i.to(device) for i in request.x]
            self.x_lens = request.x_lens.to(torch.int32)
            self.y = torch.zeros((bsz, decoder.max_seq_length)).to(torch.int32)
            self.y[:, : request.prompts.shape[-1]] = request.prompts
            self.bert_feature = [i.to(device, dtype) for i in request.bert_feature]

            self.prefill_len = self.x_lens + request.prompts.size(1)

            self.input_pos = torch.zeros_like(self.prefill_len)
            self.input_pos.add_(self.prefill_len)

            # CUDA Graph
            self.stream: Optional[torch.cuda.Stream] = None
            self.graph: Optional[torch.cuda.CUDAGraph] = None
            self.xy_pos_: Tensor
            self.xy_dec_: Tensor

            # EOS
            self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
            self.y_results: list[Tensor] = [None] * len(self.x)  # type: ignore

            self.xy_pos = decoder.embed(self.x, request.prompts, self.bert_feature)

            max_len = int(self.prefill_len.max().item())
            attn_mask = torch.zeros(size=(bsz, max_len, max_len), dtype=torch.bool)

            for bs in range(bsz):
                pos = int(self.x_lens[bs])
                seq_len = pos + y_len

                attn_mask[bs, :seq_len, :pos] = True

                ar_mask = ~torch.triu(
                    input=torch.ones(
                        size=(
                            y_len,
                            y_len,
                        ),
                        dtype=torch.bool,
                    ),
                    diagonal=1,
                )
                attn_mask[bs, pos:seq_len, pos:seq_len] = ar_mask

            self.attn_mask = attn_mask
            self.attn_mask = attn_mask.unsqueeze(0).expand(-1, decoder.n_head, -1, -1)

            self.id: int = -1

            # Sage Attn & Transformer Engine Impl
            self.cu_seqlens_q: Tensor
            self.cu_seqlens_kv: Tensor
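
The mask loop in T2SSession gives every query position full visibility of the text prefix and applies a causal pattern only over the audio-prompt span. A small sketch of one sample's mask, with toy lengths (pos and y_len here are illustrative):

import torch

pos, y_len = 3, 4  # toy text length and prompt length
seq_len = pos + y_len

mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
mask[:, :pos] = True  # every row may attend to the whole text prefix
ar_mask = ~torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1)
mask[pos:, pos:] = ar_mask  # causal attention within the prompt span

print(mask.int())  # text columns all 1s; lower-triangular block in the prompt region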
GPT_SoVITS/Accelerate/PyTorch/t2s_engine.py
ADDED
@@ -0,0 +1,223 @@
import contextlib
import gc
import os
import sys
import time
import traceback
from importlib import import_module

import torch
from rich.progress import BarColumn, Progress, TextColumn

from ..logger import SpeedColumnToken, console, logger
from .structs import T2SEngineProtocol, T2SRequest, T2SResult, T2SSession
from .t2s_model_abc import (
    CUDAGraphCacheABC,
    T2SDecoderABC,
    TorchProfiler,
)


class T2SEngine(T2SEngineProtocol):
    def __init__(
        self,
        decoder_model: T2SDecoderABC,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> None:
        assert device.type in {"cpu", "cuda", "mps", "xpu", "mtia"}
        assert dtype in {torch.float16, torch.bfloat16, torch.float32}

        self.device = device
        self.dtype = dtype

        self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)

        self.graphcache: CUDAGraphCacheABC = self.init_cache()

    def _handle_request(self, request: T2SRequest):
        with self.device:
            decoder = self.decoder_model
            session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
            batch_idx = torch.arange(session.bsz)

            t1 = 0.0
            infer_speed = 0.0
            infer_time = 0.0

            torch_profiler = TorchProfiler(request.debug)
            with (
                torch_profiler.profiler(),
                Progress(
                    TextColumn("[cyan]{task.description}"),
                    BarColumn(),
                    TextColumn("{task.completed}/{task.total} tokens"),
                    SpeedColumnToken(show_speed=True),
                    console=console,
                    transient=True,
                ) as progress,
            ):
                max_token = int(min(2000 - session.input_pos.max(), 1500))
                task = progress.add_task("T2S Decoding", total=max_token)

                for idx in range(max_token):
                    progress.update(task, advance=1)
                    if idx == 0:
                        session.kv_cache = decoder.init_cache(session.bsz)
                        xy_dec = decoder.h.prefill(session.xy_pos, session.kv_cache, session.attn_mask)
                        xy_dec = xy_dec[None, batch_idx, session.input_pos - 1]
                    else:
                        if (
                            request.use_cuda_graph
                            and session.graph is None
                            and self.graphcache.is_applicable
                            and torch.cuda.is_available()
                        ):
                            self.graphcache.assign_graph(session)

                        with torch_profiler.record("AR"):
                            if session.graph:
                                assert session.stream
                                session.stream.wait_stream(torch.cuda.default_stream())
                                with torch.cuda.stream(session.stream):
                                    session.xy_pos_.copy_(session.xy_pos)
                                    session.graph.replay()
                                    xy_dec = session.xy_dec_.clone()
                            else:
                                args, kwds = decoder.pre_forward(session)
                                xy_dec = decoder.h(
                                    session.input_pos,
                                    session.xy_pos,
                                    session.kv_cache,
                                    *args,
                                    **kwds,
                                )

                    with torch.cuda.stream(session.stream) if session.stream is not None else contextlib.nullcontext():
                        decoder.post_forward(idx, session)
                        logits = decoder.ar_predict_layer(xy_dec[:, -1])

                        if idx == 0:
                            logits[:, -1] = float("-inf")

                        with torch_profiler.record("Sampling"):
                            samples = session.sample(
                                logits=logits,
                                previous_tokens=session.y[:, : session.y_len + idx],
                                top_k=request.top_k,
                                top_p=request.top_p,
                                repetition_penalty=request.repetition_penalty,
                                temperature=request.temperature,
                            )
                            session.y[batch_idx, session.y_len + idx] = samples
                            session.input_pos.add_(1)

                        with torch_profiler.record("EOS"):
                            argmax_token = torch.argmax(logits, dim=-1)
                            sample_token = samples.squeeze(1)
                            EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)

                            newly_done_mask = EOS_mask & (~session.completed)
                            newly_done_indices = newly_done_mask.nonzero()

                            if newly_done_indices.numel() > 0:
                                for i in newly_done_indices:
                                    session.y_results[i] = session.y[i, session.y_len : session.y_len + idx]
                                session.completed[newly_done_indices] = True

                        if torch.all(session.completed).item():
                            if session.y[:, session.y_len :].sum() == 0:
                                session.y_results = [torch.tensor(0) for _ in range(session.bsz)]
                                logger.error("Bad Zero Prediction")
                            else:
                                logger.info(
                                    f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> {[i.size(-1) for i in session.y_results].__str__().strip('[]')}"
                                )
                                logger.info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
                                infer_time = time.perf_counter() - t1
                                infer_speed = (idx - 1) / infer_time
                            break

                        if (request.early_stop_num != -1 and idx >= request.early_stop_num) or idx == max_token - 1:
                            for i in range(session.bsz):
                                if not session.completed[i].item():
                                    session.y_results[i] = session.y[i, session.y_len : session.y_len + 1499]
                                    session.completed[i] = True
                            logger.error("Bad Full Prediction")
                            break

                        with torch_profiler.record("NextPos"):
                            y_emb = decoder.ar_audio_embedding(samples)
                            session.xy_pos = decoder.ar_audio_position(session.input_pos - session.x_lens, y_emb)

                    if idx == 1:
                        torch_profiler.start()
                        t1 = time.perf_counter()

                    if idx == 51:
                        torch_profiler.end()

                    if idx % 100 == 0:
                        match session.device.type:
                            case "cuda":
                                torch.cuda.empty_cache()
                            case "mps":
                                torch.mps.empty_cache()
                            case "xpu":
                                torch.xpu.empty_cache()
                            case "mtia":
                                torch.mtia.empty_cache()

            match session.device.type:
                case "cuda":
                    if session.stream is not None:
                        torch.cuda.current_stream().wait_stream(session.stream)
                    torch.cuda.empty_cache()
                case "mps":
                    torch.mps.empty_cache()
                case "xpu":
                    torch.xpu.empty_cache()
                case "mtia":
                    torch.mtia.empty_cache()
                case "cpu":
                    gc.collect()

            torch_profiler.end()
            if request.use_cuda_graph and torch.cuda.is_available():
                self.graphcache.release_graph(session)

            return session.y_results[: request.valid_length], infer_speed, infer_time

    def generate(self, request: T2SRequest):
        try:
            result, infer_speed, infer_time = self._handle_request(request)
            t2s_result = T2SResult(result=result, infer_speed=(infer_speed, infer_time), status="Success")
        except Exception as e:
            t2s_result = T2SResult(status="Error", exception=e, traceback=traceback.format_exc())
        return t2s_result

    @staticmethod
    def load_decoder(weights_path: os.PathLike, max_batch_size: int = 1, backend: str = "Flash-Attn-Varlen-CUDAGraph"):
        logger.info(f"Loading Text2Semantic Weights from {weights_path} with {backend} Backend")
        module_path = f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"
        decoder_cls_name = "T2SDecoder"
        decoder_mod = import_module(module_path, package=__package__)
        decoder_cls: type[T2SDecoderABC] = getattr(decoder_mod, decoder_cls_name)
        dict_s1 = torch.load(weights_path, map_location="cpu", weights_only=False, mmap=True)
        config = dict_s1["config"]
        decoder: T2SDecoderABC = decoder_cls(config, max_batch_size=max_batch_size)
        state_dict = dict_s1["weight"]
        decoder.load_state_dict(state_dict)

        return decoder.eval()

    def init_cache(self):
        assert self.decoder_model

        module_name = self.decoder_model.__class__.__module__
        module = sys.modules.get(module_name)
        assert module

        target_class: type[CUDAGraphCacheABC] = getattr(module, "CUDAGraphCache")

        return target_class(self.decoder_model)
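
load_decoder maps a human-readable backend name to a module under .backends by lowercasing, replacing hyphens with underscores, and splitting "cudagraph" into "cuda_graph". A standalone check of that mangling; only the first name below is the documented default, the others are hypothetical names following the same scheme:

def backend_module(backend: str) -> str:
    # the same transformation used in T2SEngine.load_decoder
    return f".backends.{backend.lower().replace('-', '_').replace('cudagraph', 'cuda_graph')}"

for name in ["Flash-Attn-Varlen-CUDAGraph", "Torch-Static-CUDAGraph", "Torch-Varlen"]:
    print(name, "->", backend_module(name))
# Flash-Attn-Varlen-CUDAGraph -> .backends.flash_attn_varlen_cuda_graph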
GPT_SoVITS/Accelerate/PyTorch/t2s_model_abc.py
ADDED
@@ -0,0 +1,672 @@
"""
Modified From https://github.com/XXXXRT666/GPT-SoVITS
"""

from __future__ import annotations

import math
import os
import random
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import MutableSequence

import torch
import torch._inductor.config
import torch.nn.functional as F
from torch.cuda.graphs import CUDAGraph
from torch.profiler import ProfilerAction, tensorboard_trace_handler

from . import nn
from .structs import KVCacheProtocol, T2SDecoderProtocol, T2SSession

Tensor = torch.Tensor


class TokenEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        vocab_size: int,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim

        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim)

    @property
    def weight(self) -> Tensor:
        return self.word_embeddings.weight

    def embedding(self, index: int) -> Tensor:
        return self.word_embeddings.weight[index : index + 1]

    def __call__(self, x: Tensor):
        x = self.word_embeddings(x)
        return x


class SinePositionalEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        scale: bool = False,
        alpha: bool = False,
        max_batch_size: int = 10,
        max_seq_len: int = 2000,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.x_scale = math.sqrt(embedding_dim) if scale else 1.0
        self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha)
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

        self.reverse = False
        self.register_buffer("pe", torch.zeros(max_batch_size, max_seq_len, embedding_dim), persistent=False)
        self.pe: torch.Tensor
        self.compute_pe()

    def compute_pe(self):
        """Reset the positional encodings."""
        if self.reverse:
            position = torch.arange(self.max_seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
        else:
            position = torch.arange(self.max_seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
        )
        pe = self.pe
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)

    def __call__(self, input_pos: Tensor, x: Tensor) -> Tensor:
        """
        Args:
            input_pos (Tensor): [batch_size, ]
            x (Tensor): [batch_size, 1, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, 1, embed_dim]
        """

        batch_size = x.shape[0]
        pe_values = self.pe[torch.arange(batch_size), input_pos - 1]  # (batch_size, embed_dim)

        return x * self.x_scale + self.alpha * pe_values.unsqueeze(1)  # (batch_size, 1, embed_dim)

    def prefill(self, x: Tensor) -> Tensor:
        """
        Args:
            x (Tensor): [batch_size, seq_len, embed_dim]

        Returns:
            embedded_x (Tensor): [batch_size, seq_len, embed_dim]
        """

        batch_size = x.shape[0]
        pe_values = self.pe[:batch_size, : x.shape[-2]]
        return x * self.x_scale + self.alpha * pe_values


class KVCacheABC(nn.Module, ABC, KVCacheProtocol):
    def __init__(self, batch_size: int, max_seq_length: int, n_heads: int, head_dim: int) -> None:
        super().__init__()

        self.n_head = n_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length

        self.k_cache: Tensor
        self.v_cache: Tensor

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    @abstractmethod
    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor, *args, **kwds) -> tuple[Tensor, Tensor]: ...

    @abstractmethod
    def prefill_kv(self, k_val: Tensor, v_val: Tensor) -> None: ...

    def sync_cache(self, kv_cache: KVCacheProtocol):
        self.k_cache.copy_(kv_cache.k_cache)
        self.v_cache.copy_(kv_cache.v_cache)


class KVCacheNHD(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        assert batch_size > 0
        cache_shape = (batch_size, max_seq_length, n_heads, head_dim)

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, 1, H, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                -1,
                self.n_head,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, 1, num_head, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(1, index, k_val)
        v_out.scatter_(1, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # input_pos: int, k_val: [B, S, H, D]

        self.k_cache[:, : k_val.shape[1]] = k_val
        self.v_cache[:, : v_val.shape[1]] = v_val


class KVCacheHND(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)

        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        index = (
            (input_pos - 1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .unsqueeze(-1)
            .expand(
                -1,
                self.n_head,
                -1,
                self.head_dim,
            )
            .to(torch.int64)
        )  # (bs, num_head, 1, head_dim)

        k_out = self.k_cache
        v_out = self.v_cache
        k_out.scatter_(2, index, k_val)
        v_out.scatter_(2, index, v_val)

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # input_pos: int, k_val: [B, S, H, D]

        self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
        self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


class KVCacheHNDVarlen(KVCacheABC):
    def __init__(self, batch_size, max_seq_length, n_heads, head_dim):
        super().__init__(batch_size, max_seq_length, n_heads, head_dim)

        cache_shape = (batch_size, n_heads, max_seq_length, head_dim)
        self.cache_idx: Tensor

        self.register_buffer("cache_idx", torch.arange(batch_size), persistent=False)
        self.register_buffer("k_cache", torch.zeros(size=cache_shape), persistent=False)
        self.register_buffer("v_cache", torch.zeros(size=cache_shape), persistent=False)

    def update(self, input_pos: Tensor, k_val: Tensor, v_val: Tensor):
        # input_pos: [B, ], k_val: [B, H, 1, D]

        k_out = self.k_cache
        v_out = self.v_cache

        ip0 = input_pos - 1

        k_out[self.cache_idx, :, ip0, None] = k_val
        v_out[self.cache_idx, :, ip0, None] = v_val

        return k_out, v_out

    def empty(self):
        self.k_cache.zero_()
        self.v_cache.zero_()

    def prefill_kv(self, k_val: Tensor, v_val: Tensor):
        # input_pos: int, k_val: [B, S, H, D]

        self.k_cache[..., : k_val.shape[1], :] = k_val.transpose(1, 2)
        self.v_cache[..., : v_val.shape[1], :] = v_val.transpose(1, 2)


class AttentionABC(nn.Module, ABC):
    def __init__(self, n_head: int, hidden_dim: int, max_seq_length: int):
        super().__init__()

        self.n_head = n_head
        self.hidden_dim = hidden_dim
        assert hidden_dim % n_head == 0
        self.head_dim = hidden_dim // n_head

        self.max_seq_length = max_seq_length

        # key, query, value projections for all heads, but in a batch
        self.in_proj: nn.Linear
        self.out_proj: nn.Linear

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        keys_to_modify = [key for key in state_dict if "in_proj_" in key]
        for key in keys_to_modify:
            new_key = key.replace("in_proj_", "in_proj.")  # in_proj_ -> in_proj.
            state_dict[new_key] = state_dict.pop(key)

    @abstractmethod
    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds) -> Tensor: ...

    def prefill(self, x: Tensor, kv_cache: KVCacheProtocol, attn_mask: Tensor) -> Tensor:
        bsz, seqlen, _ = x.shape

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        q, k, v = map(lambda x: x.contiguous().view(bsz, seqlen, self.n_head, self.head_dim), (q, k, v))

        kv_cache.prefill_kv(k, v)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        attn = F.scaled_dot_product_attention(q, k, v, attn_mask)

        attn = attn.transpose(1, 2).contiguous().view(1, -1, self.hidden_dim)

        output = self.out_proj(attn)

        return output


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()

        self.linear1 = nn.Linear(dim, hidden_dim, bias=True)
        self.linear2 = nn.Linear(hidden_dim, dim, bias=True)

    def __call__(self, x: Tensor):
        return self.linear2(F.relu(self.linear1(x), inplace=True))


class TransformerBlockABC(nn.Module, ABC):
    def __init__(self, n_head: int, ffn_dim: int, hidden_dim: int, max_seq_length: int) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.max_seq_length = max_seq_length

        self.attention: AttentionABC
        self.feed_forward: FeedForward
        self.attention_norm: nn.LayerNorm
        self.ffn_norm: nn.LayerNorm

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
        for key in list(state_dict.keys()):
            new_key = (
                key.replace("self_attn", "attention")
                .replace("linear", "feed_forward.linear")
                .replace("norm1", "attention_norm")
                .replace("norm2", "ffn_norm")
            )
            state_dict[new_key] = state_dict.pop(key)

    def __call__(self, x: Tensor, input_pos: Tensor, kv_cache: KVCacheProtocol, *args, **kwds):
        h = self.attention_norm(
            x
            + self.attention(
                x,
                input_pos,
                kv_cache,
                *args,
                **kwds,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out

    def prefill(
        self,
        x: Tensor,
        kv_cache: KVCacheProtocol,
        attn_mask: Tensor,
    ) -> Tensor:
        h = self.attention_norm(
            x
            + self.attention.prefill(
                x,
                kv_cache,
                attn_mask,
            )
        )
        out = self.ffn_norm(h + self.feed_forward(h))
        return out


class TransformerDecoderABC(nn.Module, ABC):
    def __init__(
        self,
        hidden_dim: int,
        n_layer: int,
        n_head: int,
        ffn_dim: int,
        vocab_size: int,
        max_seq_length: int,
        max_batch_size: int,
    ) -> None:
        super().__init__()

        self.hidden_dim = hidden_dim
        self.n_head = n_head
        assert hidden_dim % n_head == 0

        self.head_dim = hidden_dim // n_head
        self.vocab_size = vocab_size

        self.n_layer = n_layer

        self.layers: MutableSequence[TransformerBlockABC]

        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size

    def __call__(self, input_pos: Tensor, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer(x, input_pos, kv_cache, *args, **kwds)
        return x

    def prefill(self, x: Tensor, kv_caches: MutableSequence[KVCacheProtocol], attn_mask: Tensor):
        for layer, kv_cache in zip(self.layers, kv_caches):
            x = layer.prefill(x, kv_cache, attn_mask)
        return x


class T2SDecoderABC(nn.Module, ABC, T2SDecoderProtocol):
    def __init__(
        self,
        config: dict,
        max_seq_length: int = 2000,
        max_batch_size: int = 10,
    ) -> None:
        super().__init__()

        hidden_dim: int = config["model"]["hidden_dim"]
        embedding_dim: int = config["model"]["embedding_dim"]
        n_head: int = config["model"]["head"]
        n_layer: int = config["model"]["n_layer"]
        vocab_size: int = config["model"]["vocab_size"]
        phoneme_vocab_size: int = config["model"]["phoneme_vocab_size"]
        EOS: int = config["model"]["EOS"]
        ffn_dim: int = hidden_dim * 4

        self.n_layer = int(n_layer)
        self.hidden_dim = int(hidden_dim)
        self.n_head = int(n_head)
        assert hidden_dim % n_head == 0

        self.head_dim = int(hidden_dim // n_head)
        self.embedding_dim = int(embedding_dim)
        self.ffn_dim = int(ffn_dim)
        self.vocab_size = int(vocab_size)
        self.phoneme_vocab_size = int(phoneme_vocab_size)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        self.EOS = EOS
        assert self.EOS == self.vocab_size - 1

        self.bert_proj: nn.Linear
        self.ar_predict_layer: nn.Linear
        self.h: TransformerDecoderABC

        self.kv_class: type[KVCacheABC]

        self.GraphCache: CUDAGraphCacheABC | None
+
|
| 457 |
+
self.ar_text_embedding = TokenEmbedding(self.embedding_dim, self.phoneme_vocab_size)
|
| 458 |
+
self.ar_text_position = SinePositionalEmbedding(
|
| 459 |
+
self.embedding_dim,
|
| 460 |
+
scale=False,
|
| 461 |
+
alpha=True,
|
| 462 |
+
max_batch_size=max_batch_size,
|
| 463 |
+
max_seq_len=max_seq_length,
|
| 464 |
+
)
|
| 465 |
+
self.ar_audio_embedding = TokenEmbedding(self.embedding_dim, self.vocab_size)
|
| 466 |
+
self.ar_audio_position = SinePositionalEmbedding(
|
| 467 |
+
self.embedding_dim,
|
| 468 |
+
scale=False,
|
| 469 |
+
alpha=True,
|
| 470 |
+
max_batch_size=max_batch_size,
|
| 471 |
+
max_seq_len=max_seq_length,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
self._register_load_state_dict_pre_hook(self.load_hook)
|
| 475 |
+
|
| 476 |
+
def load_hook(self, state_dict: dict[str, Tensor], prefix, *args):
|
| 477 |
+
model_keys = [key for key in state_dict if key.startswith("model.")]
|
| 478 |
+
for key in model_keys:
|
| 479 |
+
new_key = key[len("model.") :]
|
| 480 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 481 |
+
|
| 482 |
+
def init_cache(self, bsz: int = 0) -> MutableSequence[KVCacheProtocol]:
|
| 483 |
+
bsz = bsz or self.h.max_batch_size
|
| 484 |
+
assert bsz <= self.h.max_batch_size
|
| 485 |
+
seq_lens = self.h.max_seq_length
|
| 486 |
+
dtype = self.bert_proj.bias.dtype
|
| 487 |
+
kvclass = self.kv_class
|
| 488 |
+
|
| 489 |
+
return nn.ModuleList(
|
| 490 |
+
[kvclass(bsz, seq_lens, self.n_head, self.head_dim) for _ in range(self.n_layer)],
|
| 491 |
+
).to(self.device, dtype) # type: ignore
|
| 492 |
+
|
| 493 |
+
def embed(
|
| 494 |
+
self,
|
| 495 |
+
x: list[torch.Tensor],
|
| 496 |
+
y: torch.Tensor,
|
| 497 |
+
bert_features: list[torch.Tensor],
|
| 498 |
+
):
|
| 499 |
+
x_len: list[int] = [i.shape[0] for i in x]
|
| 500 |
+
x_len_max = max(x_len)
|
| 501 |
+
xy_pos = torch.zeros((len(x), x_len_max + y.shape[1], self.embedding_dim)).to(bert_features[0].dtype)
|
| 502 |
+
|
| 503 |
+
bert_features = list(map(lambda x: x.transpose(0, 1), bert_features))
|
| 504 |
+
|
| 505 |
+
y_len = y.shape[1]
|
| 506 |
+
y_emb = self.ar_audio_embedding(y)
|
| 507 |
+
y_pos = self.ar_audio_position.prefill(y_emb)
|
| 508 |
+
|
| 509 |
+
for bs, (x_, len_, bert_feature) in enumerate(zip(x, x_len, bert_features)):
|
| 510 |
+
x_emb = self.ar_text_embedding(x_)
|
| 511 |
+
bert = self.bert_proj(bert_feature)
|
| 512 |
+
x_emb = x_emb + bert
|
| 513 |
+
x_pos = self.ar_text_position.prefill(x_emb.unsqueeze(0))
|
| 514 |
+
xy_pos[[bs], :len_] = x_pos
|
| 515 |
+
xy_pos[[bs], len_ : len_ + y_len] = y_pos
|
| 516 |
+
|
| 517 |
+
return xy_pos
|
| 518 |
+
|
| 519 |
+
def compile(self, *args, **kwds):
|
| 520 |
+
# Experimental features to reduce compilation times, will be on by default in future
|
| 521 |
+
torch._inductor.config.triton.cudagraph_skip_dynamic_graphs = True
|
| 522 |
+
torch._inductor.config.coordinate_descent_tuning = True
|
| 523 |
+
torch._inductor.config.triton.unique_kernel_names = True
|
| 524 |
+
torch._inductor.config.fx_graph_cache = True
|
| 525 |
+
torch._inductor.config.triton.cudagraph_trees = True
|
| 526 |
+
torch._inductor.config.triton.cudagraph_support_input_mutation = True
|
| 527 |
+
self.h.compile(fullgraph=True, mode="reduce-overhead")
|
| 528 |
+
|
| 529 |
+
def capture(
|
| 530 |
+
self, input_pos: Tensor, x: Tensor, x_dec: Tensor, kv_caches: MutableSequence[KVCacheProtocol], *args, **kwds
|
| 531 |
+
) -> CUDAGraph:
|
| 532 |
+
assert torch.cuda.is_available()
|
| 533 |
+
s = torch.cuda.Stream()
|
| 534 |
+
s.wait_stream(torch.cuda.current_stream())
|
| 535 |
+
|
| 536 |
+
graph = torch.cuda.CUDAGraph()
|
| 537 |
+
|
| 538 |
+
with torch.cuda.stream(s):
|
| 539 |
+
for _ in range(5):
|
| 540 |
+
self.h(input_pos, x, kv_caches, *args, **kwds)
|
| 541 |
+
torch.cuda.current_stream().wait_stream(s)
|
| 542 |
+
|
| 543 |
+
with torch.cuda.graph(graph):
|
| 544 |
+
x_dec.copy_(self.h(input_pos, x, kv_caches, *args, **kwds))
|
| 545 |
+
torch.cuda.synchronize()
|
| 546 |
+
|
| 547 |
+
return graph
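An aside on the pattern above: `capture` follows the standard CUDA Graphs recipe — warm up on a side stream so allocator state settles, then record a single decode step whose inputs and outputs live in fixed tensors. Replaying later means mutating those static tensors in place and launching the whole recorded step at once. A self-contained sketch (`step` is a stand-in, not this codebase's decoder):

```python
import torch

assert torch.cuda.is_available()

static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

def step(x):
    return x * 2.0 + 1.0  # stand-in for one decoder forward

# Warm up on a side stream, then record into a graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        step(static_in)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(step(static_in))

# Replay: overwrite the static input in place, then launch the whole step.
static_in.copy_(torch.arange(8.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_out)  # tensor([ 1., 3., 5., 7., 9., 11., 13., 15.])
```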
    @abstractmethod
    def pre_forward(self, session: T2SSession) -> tuple[list[Tensor], dict[str, Tensor]]:
        return list(), dict()

    @abstractmethod
    def post_forward(self, idx: int, session: T2SSession) -> None:
        return


class CUDAGraphCacheABC(ABC):
    def __init__(
        self,
        decoder: T2SDecoderABC,
    ) -> None:
        self.is_applicable: bool

        if torch.cuda.is_available() and self.is_applicable:
            self.device: torch.device = decoder.device
            self.dtype = decoder.bert_proj.bias.dtype

            self.assigned: bool = False

            self.decoder: T2SDecoderABC = decoder
            self.kv_cache: MutableSequence[KVCacheProtocol] = decoder.init_cache(decoder.max_batch_size)
            self.xy_pos = torch.rand(size=(decoder.max_batch_size, 1, decoder.embedding_dim), device=self.device).to(
                self.dtype
            )
            self.xy_dec = self.xy_pos.clone()

            self.input_pos = torch.tensor([10] * decoder.max_batch_size, device=self.device).int()
            self.graph: torch.cuda.CUDAGraph | None = None
            self.stream: torch.cuda.Stream | None

            self.id: int = random.randint(1, 2**32 - 1)

    def assign_graph(self, session: T2SSession):
        if self.graph is None:
            args, kwds = self.decoder.pre_forward(session)
            graph = self.decoder.capture(self.input_pos, self.xy_pos, self.xy_dec, self.kv_cache, *args, **kwds)
            self.graph = graph
            self.stream = torch.cuda.Stream()

        if self.assigned is False:
            self.get_cache_graph(session)
            session.id = self.id
            self.assigned = True
        else:
            self.capture_new_graph(session)

    @abstractmethod
    def release_graph(self, session: T2SSession): ...

    @abstractmethod
    def get_cache_graph(self, session: T2SSession):
        pass

    @abstractmethod
    def capture_new_graph(self, session: T2SSession):
        pass


class TorchProfiler:
    def __init__(self, debug: bool, log_dir: str = "./profiler") -> None:
        self.debug = debug
        self.log_dir = log_dir
        self.__profiler: torch.profiler.profile

        if self.debug and not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.tensorboard_handler = tensorboard_trace_handler(self.log_dir)

    def profiler_callback(self, prof: torch.profiler.profile):
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30))
        print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=30))
        self.tensorboard_handler(prof)

    @staticmethod
    def three_step_schedule(step: int) -> ProfilerAction:
        if step == 0:
            return ProfilerAction.NONE
        elif step == 1:
            return ProfilerAction.RECORD
        elif step == 2:
            return ProfilerAction.RECORD_AND_SAVE
        else:
            return ProfilerAction.NONE

    def start(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def end(self):
        if not self.debug:
            return
        assert self.__profiler is not None
        self.__profiler.step()

    def profiler(self):
        if self.debug:
            activities_list = [torch.profiler.ProfilerActivity.CPU]
            if torch.cuda.is_available():
                activities_list.append(torch.profiler.ProfilerActivity.CUDA)

            self.__profiler = torch.profiler.profile(
                activities=activities_list,
                record_shapes=True,
                with_stack=True,
                with_modules=True,
                profile_memory=True,
                schedule=self.three_step_schedule,
                on_trace_ready=self.profiler_callback,
            )
            return self.__profiler
        else:
            return nullcontext()

    def record(self, func_name: str):
        if self.debug:
            return torch.profiler.record_function(func_name)
        else:
            return nullcontext()
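`TorchProfiler` wraps `torch.profiler` with a three-step schedule: skip a step, record a step, then record-and-flush. A minimal sketch of the same schedule used directly (CPU-only; the matmul is a stand-in workload):

```python
import torch
from torch.profiler import ProfilerAction, ProfilerActivity, profile

def three_step_schedule(step: int) -> ProfilerAction:
    # step 0: warm-up, step 1: record, step 2: record and flush the trace
    return {1: ProfilerAction.RECORD, 2: ProfilerAction.RECORD_AND_SAVE}.get(step, ProfilerAction.NONE)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=three_step_schedule,
    on_trace_ready=lambda p: print(p.key_averages().table(row_limit=5)),
) as prof:
    for _ in range(4):
        torch.randn(256, 256) @ torch.randn(256, 256)
        prof.step()
```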
GPT_SoVITS/Accelerate/__init__.py
ADDED
@@ -0,0 +1,30 @@
from . import MLX, PyTorch
from .logger import console, logger, tb
from .PyTorch import T2SEngineTorch, T2SRequest, T2SResult
from .PyTorch.structs import T2SEngineProtocol

backends = PyTorch.backends + MLX.backends

backends = [
    b.replace("_", "-")
    .title()
    .replace("Mlx", "MLX")
    .replace("Mps", "MPS")
    .replace("Cuda", "CUDA")
    .replace("Mxfp4", "MXFP4")
    for b in backends
]


__all__ = [
    "T2SEngineTorch",
    "T2SRequest",
    "T2SResult",
    "backends",
    "MLX",
    "PyTorch",
    "logger",
    "console",
    "tb",
    "T2SEngineProtocol",
]
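The comprehension above turns snake_case backend module names into display names and restores acronym casing. For example, for the `flash_attn_varlen_cuda_graph` backend shipped in this commit:

```python
name = "flash_attn_varlen_cuda_graph"
pretty = (
    name.replace("_", "-")
    .title()                      # "Flash-Attn-Varlen-Cuda-Graph"
    .replace("Mlx", "MLX")
    .replace("Mps", "MPS")
    .replace("Cuda", "CUDA")
    .replace("Mxfp4", "MXFP4")
)
print(pretty)  # Flash-Attn-Varlen-CUDA-Graph
```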
GPT_SoVITS/Accelerate/logger.py
ADDED
@@ -0,0 +1,203 @@
import sys
from typing import Optional

from loguru import logger
from rich.console import Console, JustifyMethod
from rich.highlighter import Highlighter
from rich.logging import RichHandler
from rich.progress import Task, TextColumn
from rich.style import StyleType
from rich.table import Column
from rich.text import Text
from rich.traceback import Traceback, install

console = Console(stderr=False)
install(console=console)


def loguru_format(record):
    level = record["level"].name
    color = {
        "DEBUG": "green",
        "INFO": "blue",
        "WARNING": "yellow",
        "ERROR": "red",
        "CRITICAL": "bright_red",
    }.get(level, "white")

    return f"[bold {color}][{level}][/bold {color}] " + "{message}"


handler_with_locals = RichHandler(
    console=console,
    show_time=False,
    show_path=False,
    rich_tracebacks=True,
    tracebacks_show_locals=True,
    show_level=False,
    markup=True,
)
handler_without_locals = RichHandler(
    console=console,
    show_time=False,
    show_path=False,
    rich_tracebacks=True,
    tracebacks_show_locals=False,
    show_level=False,
    markup=True,
)


def local_filter(r):
    return r["extra"].get("show_locals", True)


logger.remove()
logger.add(handler_with_locals, format=loguru_format, filter=local_filter)
logger.add(handler_without_locals, format=loguru_format, filter=lambda x: not local_filter(x))


class SpeedColumnToken(TextColumn):
    """Progress column that renders the task's speed in tokens per second.

    Args:
        text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
        text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
        style (StyleType, optional): Style of output. Defaults to "none".
        justify (JustifyMethod, optional): Text justification. Defaults to "left".
        markup (bool, optional): Enable markup. Defaults to True.
        highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
        table_column (Optional[Column], optional): Table Column to use. Defaults to None.
        show_speed (bool, optional): Show speed instead of percentage. Defaults to True.
    """

    def __init__(
        self,
        text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
        text_format_no_percentage: str = "",
        style: StyleType = "none",
        justify: JustifyMethod = "left",
        markup: bool = True,
        highlighter: Optional[Highlighter] = None,
        table_column: Optional[Column] = None,
        show_speed: bool = True,
    ) -> None:
        self.text_format_no_percentage = text_format_no_percentage
        self.show_speed = show_speed
        super().__init__(
            text_format=text_format,
            style=style,
            justify=justify,
            markup=markup,
            highlighter=highlighter,
            table_column=table_column,
        )

    @classmethod
    def render_speed(cls, speed: Optional[float]) -> Text:
        """Render the speed in tokens per second.

        Args:
            speed (Optional[float]): Task speed in steps per second, if known.

        Returns:
            Text: Text object containing the task speed.
        """
        if speed is None:
            return Text("", style="progress.percentage")
        return Text(f"{speed:.1f} token/s", style="progress.percentage")

    def render(self, task: Task) -> Text:
        if self.show_speed:
            return self.render_speed(task.finished_speed or task.speed)
        text_format = self.text_format_no_percentage if task.total is None else self.text_format
        _text = text_format.format(task=task)
        if self.markup:
            text = Text.from_markup(_text, style=self.style, justify=self.justify)
        else:
            text = Text(_text, style=self.style, justify=self.justify)
        if self.highlighter:
            self.highlighter.highlight(text)
        return text


class SpeedColumnIteration(TextColumn):
    """Progress column that renders the task's speed in iterations per second.

    Args:
        text_format (str, optional): Format for percentage display. Defaults to "[progress.percentage]{task.percentage:>3.0f}%".
        text_format_no_percentage (str, optional): Format if percentage is unknown. Defaults to "".
        style (StyleType, optional): Style of output. Defaults to "none".
        justify (JustifyMethod, optional): Text justification. Defaults to "left".
        markup (bool, optional): Enable markup. Defaults to True.
        highlighter (Optional[Highlighter], optional): Highlighter to apply to output. Defaults to None.
        table_column (Optional[Column], optional): Table Column to use. Defaults to None.
        show_speed (bool, optional): Show speed instead of percentage. Defaults to True.
    """

    def __init__(
        self,
        text_format: str = "[progress.percentage]{task.percentage:>3.0f}%",
        text_format_no_percentage: str = "",
        style: StyleType = "none",
        justify: JustifyMethod = "left",
        markup: bool = True,
        highlighter: Optional[Highlighter] = None,
        table_column: Optional[Column] = None,
        show_speed: bool = True,
    ) -> None:
        self.text_format_no_percentage = text_format_no_percentage
        self.show_speed = show_speed
        super().__init__(
            text_format=text_format,
            style=style,
            justify=justify,
            markup=markup,
            highlighter=highlighter,
            table_column=table_column,
        )

    @classmethod
    def render_speed(cls, speed: Optional[float]) -> Text:
        """Render the speed in iterations per second.

        Args:
            speed (Optional[float]): Task speed in steps per second, if known.

        Returns:
            Text: Text object containing the task speed.
        """
        if speed is None:
            return Text("", style="progress.percentage")
        return Text(f"{speed:.1f} it/s", style="progress.percentage")

    def render(self, task: Task) -> Text:
        if self.show_speed:
            return self.render_speed(task.finished_speed or task.speed)
        text_format = self.text_format_no_percentage if task.total is None else self.text_format
        _text = text_format.format(task=task)
        if self.markup:
            text = Text.from_markup(_text, style=self.style, justify=self.justify)
        else:
            text = Text(_text, style=self.style, justify=self.justify)
        if self.highlighter:
            self.highlighter.highlight(text)
        return text


def tb(show_locals: bool = True):
    exc_type, exc_value, exc_tb = sys.exc_info()
    assert exc_type
    assert exc_value
    tb = Traceback.from_exception(exc_type, exc_value, exc_tb, show_locals=show_locals)

    return tb


__all__ = ["logger", "console", "tb", "SpeedColumnToken", "SpeedColumnIteration"]

if __name__ == "__main__":
    try:
        raise RuntimeError()
    except Exception:
        logger.bind(show_locals=False).exception("TEST")
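A usage sketch for the columns above, assuming this module is importable as `GPT_SoVITS.Accelerate.logger` (the sleep loop stands in for token-by-token decoding):

```python
import time

from rich.progress import BarColumn, Progress, TextColumn

from GPT_SoVITS.Accelerate.logger import SpeedColumnToken, console

progress = Progress(
    TextColumn("[cyan]{task.description}"),
    BarColumn(),
    SpeedColumnToken(show_speed=True),  # renders e.g. "87.3 token/s"
    console=console,
)
with progress:
    task = progress.add_task("decoding", total=100)
    for _ in range(100):
        time.sleep(0.01)
        progress.advance(task)
```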
GPT_SoVITS/configs/.gitignore
ADDED
@@ -0,0 +1 @@
*.yaml
GPT_SoVITS/configs/s2.json
ADDED
@@ -0,0 +1,91 @@
{
  "train": {
    "log_interval": 100,
    "eval_interval": 500,
    "seed": 1234,
    "epochs": 100,
    "learning_rate": 0.0001,
    "betas": [0.8, 0.99],
    "eps": 1e-09,
    "batch_size": 32,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 20480,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "text_low_lr_rate": 0.4,
    "grad_ckpt": false
  },
  "data": {
    "max_wav_value": 32768.0,
    "sampling_rate": 32000,
    "filter_length": 2048,
    "hop_length": 640,
    "win_length": 2048,
    "n_mel_channels": 128,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 300,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [10, 8, 2, 2, 2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16, 16, 8, 2, 2],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 512,
    "semantic_frame_rate": "25hz",
    "freeze_quantizer": true
  },
  "s2_ckpt_dir": "logs/s2/big2k1",
  "content_module": "cnhubert"
}
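One sanity check worth running on this config: in a HiFi-GAN-style vocoder the product of `upsample_rates` must equal the STFT `hop_length`, so each mel frame maps back to exactly one hop of samples. A short sketch against the file above:

```python
import json
from functools import reduce

with open("GPT_SoVITS/configs/s2.json") as f:
    cfg = json.load(f)

hop = cfg["data"]["hop_length"]
up = reduce(lambda a, b: a * b, cfg["model"]["upsample_rates"])
assert up == hop == 640  # 10 * 8 * 2 * 2 * 2
print(f"frame rate: {cfg['data']['sampling_rate'] / hop:.1f} Hz")  # 50.0 Hz
```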
GPT_SoVITS/configs/s2v2Pro.json
ADDED
@@ -0,0 +1,91 @@
{
  "train": {
    "log_interval": 100,
    "eval_interval": 500,
    "seed": 1234,
    "epochs": 100,
    "learning_rate": 0.0001,
    "betas": [0.8, 0.99],
    "eps": 1e-09,
    "batch_size": 32,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 20480,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "text_low_lr_rate": 0.4,
    "grad_ckpt": false
  },
  "data": {
    "max_wav_value": 32768.0,
    "sampling_rate": 32000,
    "filter_length": 2048,
    "hop_length": 640,
    "win_length": 2048,
    "n_mel_channels": 128,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 300,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.0,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [10, 8, 2, 2, 2],
    "upsample_initial_channel": 512,
    "upsample_kernel_sizes": [16, 16, 8, 2, 2],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 1024,
    "semantic_frame_rate": "25hz",
    "freeze_quantizer": true
  },
  "s2_ckpt_dir": "logs/s2/big2k1",
  "content_module": "cnhubert"
}
GPT_SoVITS/configs/s2v2ProPlus.json
ADDED
@@ -0,0 +1,91 @@
{
  "train": {
    "log_interval": 100,
    "eval_interval": 500,
    "seed": 1234,
    "epochs": 100,
    "learning_rate": 0.0001,
    "betas": [0.8, 0.99],
    "eps": 1e-09,
    "batch_size": 32,
    "fp16_run": true,
    "lr_decay": 0.999875,
    "segment_size": 20480,
    "init_lr_ratio": 1,
    "warmup_epochs": 0,
    "c_mel": 45,
    "c_kl": 1.0,
    "text_low_lr_rate": 0.4,
    "grad_ckpt": false
  },
  "data": {
    "max_wav_value": 32768.0,
    "sampling_rate": 32000,
    "filter_length": 2048,
    "hop_length": 640,
    "win_length": 2048,
    "n_mel_channels": 128,
    "mel_fmin": 0.0,
    "mel_fmax": null,
    "add_blank": true,
    "n_speakers": 300,
    "cleaned_text": true
  },
  "model": {
    "inter_channels": 192,
    "hidden_channels": 192,
    "filter_channels": 768,
    "n_heads": 2,
    "n_layers": 6,
    "kernel_size": 3,
    "p_dropout": 0.0,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [10, 8, 2, 2, 2],
    "upsample_initial_channel": 768,
    "upsample_kernel_sizes": [20, 16, 8, 2, 2],
    "n_layers_q": 3,
    "use_spectral_norm": false,
    "gin_channels": 1024,
    "semantic_frame_rate": "25hz",
    "freeze_quantizer": true
  },
  "s2_ckpt_dir": "logs/s2/big2k1",
  "content_module": "cnhubert"
}
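v2Pro and v2ProPlus differ only in vocoder width: ProPlus raises `upsample_initial_channel` from 512 to 768 and the first upsample kernel from 16 to 20. A quick programmatic diff of the two files above:

```python
import json

def flat(d, prefix=""):
    # Flatten nested dicts into dotted keys for easy comparison.
    out = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            out.update(flat(v, key + "."))
        else:
            out[key] = v
    return out

a = flat(json.load(open("GPT_SoVITS/configs/s2v2Pro.json")))
b = flat(json.load(open("GPT_SoVITS/configs/s2v2ProPlus.json")))
for key in sorted(a):
    if a[key] != b[key]:
        print(key, a[key], "->", b[key])
# model.upsample_initial_channel 512 -> 768
# model.upsample_kernel_sizes [16, 16, 8, 2, 2] -> [20, 16, 8, 2, 2]
```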
GPT_SoVITS/eres2net/ERes2NetV2.py
ADDED
@@ -0,0 +1,252 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and computational complexity.
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
both the model parameters and its computational cost.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import pooling_layers as pooling_layers
from .fusion import AFF


class ReLU(nn.Hardtanh):
    def __init__(self, inplace=False):
        super(ReLU, self).__init__(0, 20, inplace)

    def __repr__(self):
        inplace_str = "inplace" if self.inplace else ""
        return self.__class__.__name__ + " (" + inplace_str + ")"


class BasicBlockERes2NetV2(nn.Module):
    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class BasicBlockERes2NetV2AFF(nn.Module):
    def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
        super(BasicBlockERes2NetV2AFF, self).__init__()
        width = int(math.floor(planes * (baseWidth / 64.0)))
        self.conv1 = nn.Conv2d(in_planes, width * scale, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(width * scale)
        self.nums = scale
        self.expansion = expansion

        convs = []
        fuse_models = []
        bns = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
            bns.append(nn.BatchNorm2d(width))
        for j in range(self.nums - 1):
            fuse_models.append(AFF(channels=width, r=4))

        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.fuse_models = nn.ModuleList(fuse_models)
        self.relu = ReLU(inplace=True)

        self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        self.stride = stride
        self.width = width
        self.scale = scale

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = self.fuse_models[i - 1](sp, spx[i])

            sp = self.convs[i](sp)
            sp = self.relu(self.bns[i](sp))
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)

        out = self.conv3(out)
        out = self.bn3(out)

        residual = self.shortcut(x)
        out += residual
        out = self.relu(out)

        return out


class ERes2NetV2(nn.Module):
    def __init__(
        self,
        block=BasicBlockERes2NetV2,
        block_fuse=BasicBlockERes2NetV2AFF,
        num_blocks=[3, 4, 6, 3],
        m_channels=64,
        feat_dim=80,
        embedding_size=192,
        baseWidth=26,
        scale=2,
        expansion=2,
        pooling_func="TSTP",
        two_emb_layer=False,
    ):
        super(ERes2NetV2, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embedding_size = embedding_size
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer
        self.baseWidth = baseWidth
        self.scale = scale
        self.expansion = expansion

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)

        # Downsampling module
        self.layer3_ds = nn.Conv2d(
            m_channels * 4 * self.expansion,
            m_channels * 8 * self.expansion,
            kernel_size=3,
            padding=1,
            stride=2,
            bias=False,
        )

        # Bottom-up fusion module
        self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)

        self.n_stats = 1 if pooling_func == "TAP" or pooling_func == "TSDP" else 2
        self.pool = getattr(pooling_layers, pooling_func)(in_dim=self.stats_dim * self.expansion)
        self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats, embedding_size)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
            self.seg_2 = nn.Linear(embedding_size, embedding_size)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(
                block(
                    self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion
                )
            )
            self.in_planes = planes * self.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        stats = self.pool(fuse_out34)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_b
        else:
            return embed_a

    def forward3(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = x.unsqueeze_(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out1 = self.layer1(out)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)
        out3_ds = self.layer3_ds(out3)
        fuse_out34 = self.fuse34(out4, out3_ds)
        return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
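A smoke test for the network above, assuming 80-dim fbank features shaped (batch, frames, bins) and the package's relative imports resolving:

```python
import torch

from GPT_SoVITS.eres2net.ERes2NetV2 import ERes2NetV2

model = ERes2NetV2(feat_dim=80, embedding_size=192)
model.eval()

feats = torch.randn(2, 200, 80)  # (batch, frames, mel bins)
with torch.no_grad():
    emb = model(feats)
print(emb.shape)  # torch.Size([2, 192]) -- one speaker embedding per utterance
```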
GPT_SoVITS/eres2net/fusion.py
ADDED
@@ -0,0 +1,27 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch
import torch.nn as nn


class AFF(nn.Module):
    def __init__(self, channels=64, r=4):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        self.local_att = nn.Sequential(
            nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(inter_channels),
            nn.SiLU(inplace=True),
            nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x, ds_y):
        xa = torch.cat((x, ds_y), dim=1)
        x_att = self.local_att(xa)
        x_att = 1.0 + torch.tanh(x_att)
        xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0 - x_att)

        return xo
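`AFF` gates the two branches with `M = 1 + tanh(att([x; y]))`, which lies in (0, 2), and fuses them as `x*M + y*(2 - M)`, so the mixing weights always sum to 2 at every position. A toy check:

```python
import torch

from GPT_SoVITS.eres2net.fusion import AFF

aff = AFF(channels=64, r=4)
aff.eval()  # BatchNorm uses running stats, so batch size 1 is fine

x = torch.randn(1, 64, 10, 25)
y = torch.randn(1, 64, 10, 25)
with torch.no_grad():
    fused = aff(x, y)
print(fused.shape)  # torch.Size([1, 64, 10, 25]) -- same shape as each input
```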
GPT_SoVITS/eres2net/kaldi.py
ADDED
@@ -0,0 +1,844 @@
import math
from typing import Tuple

import torch
import torchaudio
from torch import Tensor

__all__ = [
    "get_mel_banks",
    "inverse_mel_scale",
    "inverse_mel_scale_scalar",
    "mel_scale",
    "mel_scale_scalar",
    "spectrogram",
    "fbank",
    "mfcc",
    "vtln_warp_freq",
    "vtln_warp_mel_freq",
]

# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 milliseconds = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001

# window types
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]


def _get_epsilon(device, dtype):
    return EPSILON.to(device=device, dtype=dtype)


def _next_power_of_2(x: int) -> int:
    r"""Returns the smallest power of 2 that is greater than x"""
    return 1 if x == 0 else 2 ** (x - 1).bit_length()


def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
    r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
    representing how the window is shifted along the waveform. Each row is a frame.

    Args:
        waveform (Tensor): Tensor of size ``num_samples``
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Returns:
        Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)
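An aside, not part of the file: `_get_strided` frames the signal without copying by building a strided view over the (possibly reflection-padded) samples — row `i` starts at sample `i * window_shift`. With toy numbers:

```python
import torch

wave = torch.arange(10.0)
frames = _get_strided(wave, window_size=4, window_shift=2, snip_edges=True)
print(frames)
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.],
#         [6., 7., 8., 9.]])  -- m = 1 + (10 - 4) // 2 = 4 frames
```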
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def _feature_window_function(
|
| 87 |
+
window_type: str,
|
| 88 |
+
window_size: int,
|
| 89 |
+
blackman_coeff: float,
|
| 90 |
+
device: torch.device,
|
| 91 |
+
dtype: int,
|
| 92 |
+
) -> Tensor:
|
| 93 |
+
r"""Returns a window function with the given type and size"""
|
| 94 |
+
if window_type == HANNING:
|
| 95 |
+
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
|
| 96 |
+
elif window_type == HAMMING:
|
| 97 |
+
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
|
| 98 |
+
elif window_type == POVEY:
|
| 99 |
+
# like hanning but goes to zero at edges
|
| 100 |
+
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
|
| 101 |
+
elif window_type == RECTANGULAR:
|
| 102 |
+
return torch.ones(window_size, device=device, dtype=dtype)
|
| 103 |
+
elif window_type == BLACKMAN:
|
| 104 |
+
a = 2 * math.pi / (window_size - 1)
|
| 105 |
+
window_function = torch.arange(window_size, device=device, dtype=dtype)
|
| 106 |
+
# can't use torch.blackman_window as they use different coefficients
|
| 107 |
+
return (
|
| 108 |
+
blackman_coeff
|
| 109 |
+
- 0.5 * torch.cos(a * window_function)
|
| 110 |
+
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
|
| 111 |
+
).to(device=device, dtype=dtype)
|
| 112 |
+
else:
|
| 113 |
+
raise Exception("Invalid window type " + window_type)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
|
| 117 |
+
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
|
| 118 |
+
device, dtype = strided_input.device, strided_input.dtype
|
| 119 |
+
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
|
| 120 |
+
if energy_floor == 0.0:
|
| 121 |
+
return log_energy
|
| 122 |
+
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def _get_waveform_and_window_properties(
|
| 126 |
+
waveform: Tensor,
|
| 127 |
+
channel: int,
|
| 128 |
+
sample_frequency: float,
|
| 129 |
+
frame_shift: float,
|
| 130 |
+
frame_length: float,
|
| 131 |
+
round_to_power_of_two: bool,
|
| 132 |
+
preemphasis_coefficient: float,
|
| 133 |
+
) -> Tuple[Tensor, int, int, int]:
|
| 134 |
+
r"""Gets the waveform and window properties"""
|
| 135 |
+
channel = max(channel, 0)
|
| 136 |
+
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
|
| 137 |
+
waveform = waveform[channel, :] # size (n)
|
| 138 |
+
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
|
| 139 |
+
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
|
| 140 |
+
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
|
| 141 |
+
|
| 142 |
+
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
|
| 143 |
+
window_size, len(waveform)
|
| 144 |
+
)
|
| 145 |
+
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
| 146 |
+
assert padded_window_size % 2 == 0, (
|
| 147 |
+
"the padded `window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`"
|
| 148 |
+
)
|
| 149 |
+
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
| 150 |
+
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
| 151 |
+
return waveform, window_shift, window_size, padded_window_size
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def _get_window(
    waveform: Tensor,
    padded_window_size: int,
    window_size: int,
    window_shift: int,
    window_type: str,
    blackman_coeff: float,
    snip_edges: bool,
    raw_energy: bool,
    energy_floor: float,
    dither: float,
    remove_dc_offset: bool,
    preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
    r"""Gets a window and its log energy

    Returns:
        (Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    # size (m, window_size)
    strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)

    if dither != 0.0:
        rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
        strided_input = strided_input + rand_gauss * dither

    if remove_dc_offset:
        # Subtract each row/frame by its mean
        row_means = torch.mean(strided_input, dim=1).unsqueeze(1)  # size (m, 1)
        strided_input = strided_input - row_means

    if raw_energy:
        # Compute the log energy of each row/frame before applying preemphasis and
        # window function
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    if preemphasis_coefficient != 0.0:
        # strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
        offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
            0
        )  # size (m, window_size + 1)
        strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]

    # Apply window_function to each row/frame
    window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
        0
    )  # size (1, window_size)
    strided_input = strided_input * window_function  # size (m, window_size)

    # Pad columns with zero until we reach size (m, padded_window_size)
    if padded_window_size != window_size:
        padding_right = padded_window_size - window_size
        strided_input = torch.nn.functional.pad(
            strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
        ).squeeze(0)

    # Compute energy after window function (not the raw one)
    if not raw_energy:
        signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor)  # size (m)

    return strided_input, signal_log_energy


def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
    # subtracts the column mean of the tensor size (m, n) if subtract_mean=True
    # it returns size (m, n)
    if subtract_mean:
        col_means = torch.mean(tensor, dim=0).unsqueeze(0)
        tensor = tensor - col_means
    return tensor


def spectrogram(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    min_duration: float = 0.0,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
    compute-spectrogram-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A spectrogram identical to what Kaldi would output. The shape is
        (m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype
    epsilon = _get_epsilon(device, dtype)

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0)

    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1), complex
    fft = torch.fft.rfft(strided_input)

    # Convert the FFT into a power spectrum
    power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log()  # size (m, padded_window_size // 2 + 1)
    power_spectrum[:, 0] = signal_log_energy

    power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
    return power_spectrum
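
# Illustrative sketch (hypothetical helper): spectrogram() above on one second of
# dummy 16 kHz mono audio under the Kaldi defaults yields 98 frames of 257 bins
# (512-point FFT -> 512 // 2 + 1 columns).
def _demo_spectrogram():
    feats = spectrogram(torch.randn(1, 16000))
    assert feats.shape == (98, 257)  # 1 + (16000 - 400) // 160 frames
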
def inverse_mel_scale_scalar(mel_freq: float) -> float:
    return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)


def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
    return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)


def mel_scale_scalar(freq: float) -> float:
    return 1127.0 * math.log(1.0 + freq / 700.0)


def mel_scale(freq: Tensor) -> Tensor:
    return 1127.0 * (1.0 + freq / 700.0).log()
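
# Worked numbers for the mel mapping above, m = 1127 * ln(1 + f / 700): 1000 Hz lands
# at roughly 1000 mel, and the scalar pair round-trips (illustrative helper, not part
# of the original file).
def _demo_mel_scale():
    m = mel_scale_scalar(1000.0)     # ~999.99 mel
    f = inverse_mel_scale_scalar(m)  # back to 1000.0 Hz
    assert abs(m - 1000.0) < 0.1 and abs(f - 1000.0) < 1e-6
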
def vtln_warp_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq: float,
    high_freq: float,
    vtln_warp_factor: float,
    freq: Tensor,
) -> Tensor:
    r"""This computes a VTLN warping function that is not the same as HTK's one,
    but has similar inputs (this function has the advantage of never producing
    empty bins).

    This function computes a warp function F(freq), defined between low_freq
    and high_freq inclusive, with the following properties:
        F(low_freq) == low_freq
        F(high_freq) == high_freq
    The function is continuous and piecewise linear with two inflection
    points.
    The lower inflection point (measured in terms of the unwarped
    frequency) is at frequency l, determined as described below.
    The higher inflection point is at a frequency h, determined as
    described below.
    If l <= f <= h, then F(f) = f/vtln_warp_factor.
    If the higher inflection point (measured in terms of the unwarped
    frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
    Since (by the last point) F(h) == h/vtln_warp_factor, then
    max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
        h = vtln_high_cutoff / max(1, 1/vtln_warp_factor)
          = vtln_high_cutoff * min(1, vtln_warp_factor).
    If the lower inflection point (measured in terms of the unwarped
    frequency) is at l, then min(l, F(l)) == vtln_low_cutoff.
    This implies that
        l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
          = vtln_low_cutoff * max(1, vtln_warp_factor)

    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        freq (Tensor): given frequency in Hz

    Returns:
        Tensor: Freq after vtln warp
    """
    assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
    assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
    l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
    h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
    scale = 1.0 / vtln_warp_factor
    Fl = scale * l  # F(l)
    Fh = scale * h  # F(h)
    assert l > low_freq and h < high_freq
    # slope of left part of the 3-piece linear function
    scale_left = (Fl - low_freq) / (l - low_freq)
    # [slope of center part is just "scale"]

    # slope of right part of the 3-piece linear function
    scale_right = (high_freq - Fh) / (high_freq - h)

    res = torch.empty_like(freq)

    outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq)  # freq < low_freq || freq > high_freq
    before_l = torch.lt(freq, l)  # freq < l
    before_h = torch.lt(freq, h)  # freq < h
    after_h = torch.ge(freq, h)  # freq >= h

    # order of operations matters here (since the frequency regions overlap)
    res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
    res[before_h] = scale * freq[before_h]
    res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
    res[outside_low_high_freq] = freq[outside_low_high_freq]

    return res


def vtln_warp_mel_freq(
    vtln_low_cutoff: float,
    vtln_high_cutoff: float,
    low_freq,
    high_freq: float,
    vtln_warp_factor: float,
    mel_freq: Tensor,
) -> Tensor:
    r"""
    Args:
        vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
        vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
        low_freq (float): Lower frequency cutoffs in mel computation
        high_freq (float): Upper frequency cutoffs in mel computation
        vtln_warp_factor (float): Vtln warp factor
        mel_freq (Tensor): Given frequency in Mel

    Returns:
        Tensor: ``mel_freq`` after vtln warp
    """
    return mel_scale(
        vtln_warp_freq(
            vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
        )
    )
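
# Hedged sanity check on the warp above (hypothetical helper, cutoff values are
# illustrative): between the two inflection points, F(f) = f / vtln_warp_factor,
# so a 0.9 warp stretches mid-band frequencies by about 11%.
def _demo_vtln_warp():
    f = torch.tensor([300.0, 1000.0, 4000.0])
    warped = vtln_warp_freq(100.0, 7000.0, 20.0, 7600.0, 0.9, f)
    assert torch.allclose(warped, f / 0.9)  # all three sit in the middle piece
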
def get_mel_banks(
    num_bins: int,
    window_length_padded: int,
    sample_freq: float,
    low_freq: float,
    high_freq: float,
    vtln_low: float,
    vtln_high: float,
    vtln_warp_factor: float,
    device=None,
    dtype=None,
) -> Tuple[Tensor, Tensor]:
    """
    Returns:
        (Tensor, Tensor): The tuple consists of ``bins`` (which is
        melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
        center frequencies of bins of size (``num_bins``)).
    """
    assert num_bins > 3, "Must have at least 3 mel bins"
    assert window_length_padded % 2 == 0
    num_fft_bins = window_length_padded / 2
    nyquist = 0.5 * sample_freq

    if high_freq <= 0.0:
        high_freq += nyquist

    assert (0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq), (
        "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
    )

    # fft-bin width [think of it as Nyquist-freq / half-window-length]
    fft_bin_width = sample_freq / window_length_padded
    mel_low_freq = mel_scale_scalar(low_freq)
    mel_high_freq = mel_scale_scalar(high_freq)

    # divide by num_bins+1 in next line because of end-effects where the bins
    # spread out to the sides.
    mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)

    if vtln_high < 0.0:
        vtln_high += nyquist

    assert vtln_warp_factor == 1.0 or (
        (low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
    ), "Bad values in options: vtln-low {} and vtln-high {}, versus low-freq {} and high-freq {}".format(
        vtln_low, vtln_high, low_freq, high_freq
    )

    bin = torch.arange(num_bins).unsqueeze(1)
    left_mel = mel_low_freq + bin * mel_freq_delta  # size(num_bins, 1)
    center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta  # size(num_bins, 1)
    right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta  # size(num_bins, 1)

    if vtln_warp_factor != 1.0:
        left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
        center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
        right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)

    # center_freqs = inverse_mel_scale(center_mel)  # size (num_bins)
    # size(1, num_fft_bins)
    mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)

    # size (num_bins, num_fft_bins)
    up_slope = (mel - left_mel) / (center_mel - left_mel)
    down_slope = (right_mel - mel) / (right_mel - center_mel)

    if vtln_warp_factor == 1.0:
        # left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
        bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
    else:
        # warping can move the order of left_mel, center_mel, right_mel anywhere
        bins = torch.zeros_like(up_slope)
        up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel)  # left_mel < mel <= center_mel
        down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel)  # center_mel < mel < right_mel
        bins[up_idx] = up_slope[up_idx]
        bins[down_idx] = down_slope[down_idx]

    return bins.to(device=device, dtype=dtype)  # , center_freqs


cache = {}
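
# Shape sketch for the filterbank matrix above, with the same arguments fbank()
# passes at 16 kHz for 80 bins (illustrative helper, not part of the original file).
def _demo_mel_banks():
    banks = get_mel_banks(80, 512, 16000.0, 20.0, 0.0, 100.0, -500.0, 1.0)
    assert banks.shape == (80, 256)  # one triangular filter per row, 512 / 2 columns
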
def fbank(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    use_log_fbank: bool = True,
    use_power: bool = True,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
    compute-fbank-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
            (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
        use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
        where m is calculated in _get_strided
    """
    device, dtype = waveform.device, waveform.dtype

    waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
        waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
    )

    if len(waveform) < min_duration * sample_frequency:
        # signal is too short
        return torch.empty(0, device=device, dtype=dtype)

    # strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
    strided_input, signal_log_energy = _get_window(
        waveform,
        padded_window_size,
        window_size,
        window_shift,
        window_type,
        blackman_coeff,
        snip_edges,
        raw_energy,
        energy_floor,
        dither,
        remove_dc_offset,
        preemphasis_coefficient,
    )

    # size (m, padded_window_size // 2 + 1)
    spectrum = torch.fft.rfft(strided_input).abs()
    if use_power:
        spectrum = spectrum.pow(2.0)

    # size (num_mel_bins, padded_window_size // 2)
    # print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)

    cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (
        num_mel_bins,
        padded_window_size,
        sample_frequency,
        low_freq,
        high_freq,
        vtln_low,
        vtln_high,
        vtln_warp,
        device,
        dtype,
    )
    if cache_key not in cache:
        mel_energies = get_mel_banks(
            num_mel_bins,
            padded_window_size,
            sample_frequency,
            low_freq,
            high_freq,
            vtln_low,
            vtln_high,
            vtln_warp,
            device,
            dtype,
        )
        cache[cache_key] = mel_energies
    else:
        mel_energies = cache[cache_key]

    # pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
    mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)

    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
    mel_energies = torch.mm(spectrum, mel_energies.T)
    if use_log_fbank:
        # avoid log of zero (which should be prevented anyway by dithering)
        mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

    # if use_energy then add it as the last column for htk_compat == true else first column
    if use_energy:
        signal_log_energy = signal_log_energy.unsqueeze(1)  # size (m, 1)
        # returns size (m, num_mel_bins + 1)
        if htk_compat:
            mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
        else:
            mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)

    mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
    return mel_energies
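
# Hedged usage sketch for fbank() above (hypothetical helper): 80-bin log-mels at
# 16 kHz, the kind of front end a speaker encoder such as the ERes2NetV2 in this
# directory typically consumes (an assumption; the actual call site lives elsewhere).
def _demo_fbank():
    wave = torch.randn(1, 16000) * 32768.0  # Kaldi-style 16-bit sample scale (assumption)
    feats = fbank(wave, num_mel_bins=80, sample_frequency=16000.0)
    assert feats.shape == (98, 80)
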
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
    # returns a dct matrix of size (num_mel_bins, num_ceps)
    # size (num_mel_bins, num_mel_bins)
    dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
    # kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
    # this would be the first column in the dct_matrix for torchaudio as it expects a
    # right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
    # expects a left multiply e.g. dct_matrix * vector).
    dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
    dct_matrix = dct_matrix[:, :num_ceps]
    return dct_matrix


def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
    # returns size (num_ceps)
    # Compute liftering coefficients (scaling on cepstral coeffs)
    # coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
    i = torch.arange(num_ceps)
    return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
    waveform: Tensor,
    blackman_coeff: float = 0.42,
    cepstral_lifter: float = 22.0,
    channel: int = -1,
    dither: float = 0.0,
    energy_floor: float = 1.0,
    frame_length: float = 25.0,
    frame_shift: float = 10.0,
    high_freq: float = 0.0,
    htk_compat: bool = False,
    low_freq: float = 20.0,
    num_ceps: int = 13,
    min_duration: float = 0.0,
    num_mel_bins: int = 23,
    preemphasis_coefficient: float = 0.97,
    raw_energy: bool = True,
    remove_dc_offset: bool = True,
    round_to_power_of_two: bool = True,
    sample_frequency: float = 16000.0,
    snip_edges: bool = True,
    subtract_mean: bool = False,
    use_energy: bool = False,
    vtln_high: float = -500.0,
    vtln_low: float = 100.0,
    vtln_warp: float = 1.0,
    window_type: str = POVEY,
) -> Tensor:
    r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
    compute-mfcc-feats.

    Args:
        waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
        blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
        cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
        channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
        dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
            the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
        energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
            this floor is applied to the zeroth component, representing the total signal energy. The floor on the
            individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
        frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
        frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
        high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
            (Default: ``0.0``)
        htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
            features (need to change other parameters). (Default: ``False``)
        low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
        num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
        min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
        num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
        preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
        raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
        remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
        round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
            to FFT. (Default: ``True``)
        sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
            specified there) (Default: ``16000.0``)
        snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
        subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
            it this way. (Default: ``False``)
        use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
        vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
            negative, offset from high-mel-freq) (Default: ``-500.0``)
        vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
        vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
        window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
            (Default: ``'povey'``)

    Returns:
        Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
        where m is calculated in _get_strided
    """
    assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)

    device, dtype = waveform.device, waveform.dtype

    # The mel_energies should be computed with power (use_power=True), without mean
    # subtraction (subtract_mean=False), and with log applied (use_log_fbank=True).
    # size (m, num_mel_bins + use_energy)
    feature = fbank(
        waveform=waveform,
        blackman_coeff=blackman_coeff,
        channel=channel,
        dither=dither,
        energy_floor=energy_floor,
        frame_length=frame_length,
        frame_shift=frame_shift,
        high_freq=high_freq,
        htk_compat=htk_compat,
        low_freq=low_freq,
        min_duration=min_duration,
        num_mel_bins=num_mel_bins,
        preemphasis_coefficient=preemphasis_coefficient,
        raw_energy=raw_energy,
        remove_dc_offset=remove_dc_offset,
        round_to_power_of_two=round_to_power_of_two,
        sample_frequency=sample_frequency,
        snip_edges=snip_edges,
        subtract_mean=False,
        use_energy=use_energy,
        use_log_fbank=True,
        use_power=True,
        vtln_high=vtln_high,
        vtln_low=vtln_low,
        vtln_warp=vtln_warp,
        window_type=window_type,
    )

    if use_energy:
        # size (m)
        signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
        # offset is 0 if htk_compat==True else 1
        mel_offset = int(not htk_compat)
        feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]

    # size (num_mel_bins, num_ceps)
    dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)

    # size (m, num_ceps)
    feature = feature.matmul(dct_matrix)

    if cepstral_lifter != 0.0:
        # size (1, num_ceps)
        lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
        feature *= lifter_coeffs.to(device=device, dtype=dtype)

    # if use_energy then replace the last column for htk_compat == true else first column
    if use_energy:
        feature[:, 0] = signal_log_energy

    if htk_compat:
        energy = feature[:, 0].unsqueeze(1)  # size (m, 1)
        feature = feature[:, 1:]  # size (m, num_ceps - 1)
        if not use_energy:
            # scale on C0 (actually removing a scale we previously added that's
            # part of one common definition of the cosine transform.)
            energy *= math.sqrt(2)

        feature = torch.cat((feature, energy), dim=1)

    feature = _subtract_column_mean(feature, subtract_mean)
    return feature
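
# Matching sketch for mfcc() above (hypothetical helper): Kaldi's 13-cepstra,
# 23-mel-bin defaults on the same dummy one-second waveform.
def _demo_mfcc():
    mfccs = mfcc(torch.randn(1, 16000), num_ceps=13, num_mel_bins=23)
    assert mfccs.shape == (98, 13)
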
GPT_SoVITS/eres2net/pooling_layers.py
ADDED
@@ -0,0 +1,101 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

"""This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""

import torch
import torch.nn as nn


class TAP(nn.Module):
    """
    Temporal average pooling, only first-order mean is considered
    """

    def __init__(self, **kwargs):
        super(TAP, self).__init__()

    def forward(self, x):
        pooling_mean = x.mean(dim=-1)
        # To be compatible with 2D input
        pooling_mean = pooling_mean.flatten(start_dim=1)
        return pooling_mean


class TSDP(nn.Module):
    """
    Temporal standard deviation pooling, only second-order std is considered
    """

    def __init__(self, **kwargs):
        super(TSDP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_std = pooling_std.flatten(start_dim=1)
        return pooling_std


class TSTP(nn.Module):
    """
    Temporal statistics pooling, concatenate mean and std, which is used in
    x-vector
    Comment: simple concatenation cannot make full use of both statistics
    """

    def __init__(self, **kwargs):
        super(TSTP, self).__init__()

    def forward(self, x):
        # The last dimension is the temporal axis
        pooling_mean = x.mean(dim=-1)
        pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
        pooling_mean = pooling_mean.flatten(start_dim=1)
        pooling_std = pooling_std.flatten(start_dim=1)

        stats = torch.cat((pooling_mean, pooling_std), 1)
        return stats


class ASTP(nn.Module):
    """Attentive statistics pooling: Channel- and context-dependent
    statistics pooling, first used in ECAPA_TDNN.
    """

    def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
        super(ASTP, self).__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't
        # need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, bottleneck_dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        """
        x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
           or a 4-dimensional tensor in resnet architecture (B,C,F,T)
        0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
        """
        if len(x.shape) == 4:
            x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
        assert len(x.shape) == 3

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! ReLU may be hard to converge.
        alpha = torch.tanh(self.linear1(x_in))  # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        var = torch.sum(alpha * (x**2), dim=2) - mean**2
        std = torch.sqrt(var.clamp(min=1e-10))
        return torch.cat([mean, std], dim=1)
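
# Illustrative sketch (hypothetical helper, not part of the original file): ASTP on a
# ResNet-style 4-D feature map. in_dim must equal C * F once the map is flattened to
# (B, C*F, T); the dimensions below are made up for demonstration.
def _demo_astp():
    pool = ASTP(in_dim=2560)         # e.g. C=80 channels x F=32 freq bins
    x = torch.randn(4, 80, 32, 200)  # (B, C, F, T)
    emb = pool(x)
    assert emb.shape == (4, 5120)    # concatenated attentive mean and std
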
GPT_SoVITS/f5_tts/model/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .backbones.dit import DiT

__all__ = ["DiT"]
GPT_SoVITS/f5_tts/model/backbones/README.md
ADDED
@@ -0,0 +1,20 @@
## Backbones quick introduction


### unett.py
- flat unet transformer
- structure same as in e2-tts & voicebox paper except using rotary pos emb
- update: allow possible abs pos emb & convnextv2 blocks for embedded text before concat

### dit.py
- adaln-zero dit
- embedded timestep as condition
- concatenated noised_input + masked_cond + embedded_text, linear proj in
- possible abs pos emb & convnextv2 blocks for embedded text before concat
- possible long skip connection (first layer to last layer)

### mmdit.py
- sd3 structure
- timestep as condition
- left stream: text embedded and applied an abs pos emb
- right stream: masked_cond & noised_input concatenated, with the same conv pos emb as unett
GPT_SoVITS/f5_tts/model/backbones/dit.py
ADDED
@@ -0,0 +1,193 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint
from x_transformers.x_transformers import RotaryEmbedding

from GPT_SoVITS.module.commons import sequence_mask

from ..modules import (
    AdaLayerNormZero_Final,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    TimestepEmbedding,
    get_pos_embed_indices,
    precompute_freqs_cis,
)


class TextEmbedding(nn.Module):
    def __init__(self, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        batch, text_len = text.shape[0], text.shape[1]

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]

            # print(23333333,text.shape,text_pos_embed.shape)#torch.Size([7, 465, 256]) torch.Size([7, 465, 256])

            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using DiT blocks


class DiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_dim=None,
        conv_layers=0,
        long_skip_connection=False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.d_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def ckpt_wrapper(self, module):
        # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
        def ckpt_forward(*inputs):
            outputs = module(*inputs)
            return outputs

        return ckpt_forward

    def forward(  # x, prompt_x, x_lens, t, style, cond
        self,  # d is channel, n is T
        x0: float["b n d"],  # noised input audio  # noqa: F722
        cond0: float["b n d"],  # masked cond audio  # noqa: F722
        x_lens,
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        dt_base_bootstrap,
        text0,  # : int["b nt"]  # noqa: F722  # condition feature
        use_grad_ckpt=False,  # bool
        ###no-use
        drop_audio_cond=False,  # cfg for cond audio
        drop_text=False,  # cfg for text
        # mask: bool["b n"] | None = None,  # noqa: F722
        infer=False,  # bool
        text_cache=None,  # torch tensor as text_embed
        dt_cache=None,  # torch tensor as dt
    ):
        x = x0.transpose(2, 1)
        cond = cond0.transpose(2, 1)
        text = text0.transpose(2, 1)
        mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)

        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        if infer and dt_cache is not None:
            dt = dt_cache
        else:
            dt = self.d_embed(dt_base_bootstrap)
        t += dt

        if infer and text_cache is not None:
            text_embed = text_cache
        else:
            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  # need to change

        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            if use_grad_ckpt:
                x = checkpoint(self.ckpt_wrapper(block), x, t, mask, rope, use_reentrant=False)
            else:
                x = block(x, t, mask=mask, rope=rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim=-1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        if infer:
            return output, text_embed, dt
        else:
            return output
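
# Hedged construction sketch for the DiT backbone above (hypothetical helper; these
# hyperparameters are illustrative, not the repo's shipped config, and x_transformers
# plus the sibling ..modules package must be importable).
def _demo_dit():
    model = DiT(dim=1024, depth=22, heads=16, dim_head=64, mel_dim=100, text_dim=512, conv_layers=4)
    print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
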
GPT_SoVITS/f5_tts/model/backbones/mmdit.py
ADDED
@@ -0,0 +1,144 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
from x_transformers.x_transformers import RotaryEmbedding

from ..modules import (
    AdaLayerNormZero_Final,
    ConvPositionEmbedding,
    MMDiTBlock,
    TimestepEmbedding,
    get_pos_embed_indices,
    precompute_freqs_cis,
)

# text embedding


class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]:  # noqa: F722
        text = text + 1
        if drop_text:
            text = torch.zeros_like(text)
        text = self.text_embed(text)

        # sinus pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        return text


# noised input & masked cond audio embedding


class AudioEmbedding(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim=-1)
        x = self.linear(x)
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using MM-DiT blocks


class MMDiT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        text_num_embeds=256,
        mel_dim=100,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim=dim,
                    heads=heads,
                    dim_head=dim_head,
                    dropout=dropout,
                    ff_mult=ff_mult,
                    context_pre_only=i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        batch = x.shape[0]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c = self.text_embed(text, drop_text=drop_text)
        x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
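
# Minimal forward-pass sketch for MMDiT above with dummy shapes (hypothetical helper;
# assumes x_transformers and the sibling ..modules blocks are importable and behave as
# their call sites in this file suggest).
def _demo_mmdit():
    model = MMDiT(dim=512, depth=4, heads=8, dim_head=64, text_num_embeds=256, mel_dim=100)
    x = torch.randn(2, 300, 100)            # noised mel frames
    cond = torch.randn(2, 300, 100)         # masked conditioning mel
    text = torch.randint(0, 256, (2, 120))  # token ids
    out = model(x, cond, text, torch.rand(2), drop_audio_cond=False, drop_text=False)
    assert out.shape == (2, 300, 100)
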
GPT_SoVITS/f5_tts/model/backbones/unett.py
ADDED
|
@@ -0,0 +1,218 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

from typing import Literal

import torch
import torch.nn.functional as F
from torch import nn
from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from ..modules import (
    Attention,
    AttnProcessor,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    FeedForward,
    TimestepEmbedding,
    get_pos_embed_indices,
    precompute_freqs_cis,
)

# Text embedding


class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24kHz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        else:
            self.extra_modeling = False

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding


class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x


# Flat UNet Transformer backbone


class UNetT(nn.Module):
    def __init__(
        self,
        *,
        dim,
        depth=8,
        heads=8,
        dim_head=64,
        dropout=0.1,
        ff_mult=4,
        mel_dim=100,
        text_num_embeds=256,
        text_dim=None,
        conv_layers=0,
        skip_connect_type: Literal["add", "concat", "none"] = "concat",
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        # transformer layers & skip connections

        self.dim = dim
        self.skip_connect_type = skip_connect_type
        needs_skip_proj = skip_connect_type == "concat"

        self.depth = depth
        self.layers = nn.ModuleList([])

        for idx in range(depth):
            is_later_half = idx >= (depth // 2)

            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor=AttnProcessor(),
                dim=dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

            skip_proj = nn.Linear(dim * 2, dim, bias=False) if needs_skip_proj and is_later_half else None

            self.layers.append(
                nn.ModuleList(
                    [
                        skip_proj,
                        attn_norm,
                        attn,
                        ff_norm,
                        ff,
                    ]
                )
            )

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool["b n"] | None = None,  # noqa: F722
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

        # prefix time t to input x, [b n d] -> [b n+1 d]
        x = torch.cat([t.unsqueeze(1), x], dim=1)  # pack t to x
        if mask is not None:
            mask = F.pad(mask, (1, 0), value=1)

        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)

        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            layer = idx + 1

            # skip connection logic
            is_first_half = layer <= (self.depth // 2)
            is_later_half = not is_first_half

            if is_first_half:
                skips.append(x)

            if is_later_half:
                skip = skips.pop()
                if skip_connect_type == "concat":
                    x = torch.cat((x, skip), dim=-1)
                    x = maybe_skip_proj(x)
                elif skip_connect_type == "add":
                    x = x + skip

            # attention and feedforward blocks
            x = attn(attn_norm(x), rope=rope, mask=mask) + x
            x = ff(ff_norm(x)) + x

        assert len(skips) == 0

        x = self.norm_out(x)[:, 1:, :]  # unpack t from x

        return self.proj_out(x)
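
Note: a hypothetical smoke test (not part of this commit). `depth` must be even because layer i in the first half is paired with layer depth-1-i in the second half, and with skip_connect_type="concat" the later-half layers carry a Linear(2*dim, dim) projection:

# Hypothetical usage sketch, not part of the commit.
import torch

unet = UNetT(dim=512, depth=8, heads=8, dim_head=64, mel_dim=100, text_num_embeds=256, skip_connect_type="concat")
b, n, nt = 2, 96, 24
x = torch.randn(b, n, 100)               # noised input audio
cond = torch.randn(b, n, 100)            # masked cond audio
text = torch.randint(0, 256, (b, nt))    # token ids (padded with -1 upstream; see TextEmbedding)
time = torch.rand(b)
out = unet(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                         # torch.Size([2, 96, 100])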
GPT_SoVITS/f5_tts/model/modules.py
ADDED
@@ -0,0 +1,665 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import math
from typing import Optional

import torch
import torch.nn.functional as F
import torchaudio
from librosa.filters import mel as librosa_mel_fn
from torch import nn
from x_transformers.x_transformers import apply_rotary_pos_emb

# raw wav to mel spec


mel_basis_cache = {}
hann_window_cache = {}


def get_bigvgan_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
    fmin=0,
    fmax=None,
    center=False,
):  # Copy from https://github.com/NVIDIA/BigVGAN/tree/main
    device = waveform.device
    key = f"{n_fft}_{n_mel_channels}_{target_sample_rate}_{hop_length}_{win_length}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(sr=target_sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=fmin, fmax=fmax)
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)  # TODO: why do they need .float()?
        hann_window_cache[key] = torch.hann_window(win_length).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    padding = (n_fft - hop_length) // 2
    waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)

    spec = torch.stft(
        waveform,
        n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

    return mel_spec
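
Note: with center=False and reflect padding of (n_fft - hop_length) // 2 samples on each side, the frame count works out to len(wav) // hop_length. A quick sanity check (hypothetical, not part of this commit):

# Hypothetical sanity check, not part of the commit.
import torch

wav = torch.randn(1, 24000)     # 1 second at 24 kHz
mel = get_bigvgan_mel_spectrogram(wav)
print(mel.shape)                # torch.Size([1, 100, 93]); 24000 // 256 == 93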

def get_vocos_mel_spectrogram(
    waveform,
    n_fft=1024,
    n_mel_channels=100,
    target_sample_rate=24000,
    hop_length=256,
    win_length=1024,
):
    mel_stft = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mel_channels,
        power=1,
        center=True,
        normalized=False,
        norm=None,
    ).to(waveform.device)
    if len(waveform.shape) == 3:
        waveform = waveform.squeeze(1)  # 'b 1 nw -> b nw'

    assert len(waveform.shape) == 2

    mel = mel_stft(waveform)
    mel = mel.clamp(min=1e-5).log()
    return mel


class MelSpec(nn.Module):
    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=100,
        target_sample_rate=24_000,
        mel_spec_type="vocos",
    ):
        super().__init__()
        assert mel_spec_type in ["vocos", "bigvgan"], "We only support two mel extraction backends: vocos or bigvgan"

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.target_sample_rate = target_sample_rate

        if mel_spec_type == "vocos":
            self.extractor = get_vocos_mel_spectrogram
        elif mel_spec_type == "bigvgan":
            self.extractor = get_bigvgan_mel_spectrogram

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

    def forward(self, wav):
        if self.dummy.device != wav.device:
            self.to(wav.device)

        mel = self.extractor(
            waveform=wav,
            n_fft=self.n_fft,
            n_mel_channels=self.n_mel_channels,
            target_sample_rate=self.target_sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
        )

        return mel

# sinusoidal position embedding


class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding


class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size=31, groups=16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float["b n d"], mask: bool["b n"] | None = None):  # noqa: F722
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.0)

        x = x.permute(0, 2, 1)
        x = self.conv1d(x)
        out = x.permute(0, 2, 1)

        if mask is not None:
            out = out.masked_fill(~mask, 0.0)

        return out


# rotary positional embedding related


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)


def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # avoid extra long error.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos
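
Note: the NTK-aware rescaling in precompute_freqs_cis above, written out for reference (a restatement of the code, not an addition to it): with pair dimension $d$, base $\theta$, and rescale factor $r$,

$$\theta' = \theta \cdot r^{\,d/(d-2)}, \qquad \omega_j = (\theta')^{-2j/d}, \quad j = 0, \ldots, d/2 - 1,$$

and the returned tensor packs $[\cos(t\,\omega_j) \,\Vert\, \sin(t\,\omega_j)]$ for positions $t = 0, \ldots, \mathrm{end}-1$, giving shape (end, dim).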

# Global Response Normalization layer (Instance Normalization ?)


class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108


class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(
            dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
        )  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x


# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation


class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb=None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for final layer
# returns only the modulated x for attn input, since no further mlp modulation follows


class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
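
Note: answering the "(Instance Normalization ?)" question in the comment above, GRN computes the ConvNeXt-V2 global response normalization, not instance normalization. With the L2 norm taken over the sequence dimension and the mean over the $C$ channels,

$$G(x)_{b,c} = \lVert x_{b,:,c} \rVert_2, \qquad N(x)_{b,c} = \frac{G(x)_{b,c}}{\tfrac{1}{C}\sum_{c'} G(x)_{b,c'} + 10^{-6}}, \qquad y = \gamma \odot (x \cdot N(x)) + \beta + x,$$

so channels are rescaled by their global response relative to the channel mean, with a residual term.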

# FeedForward


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
        self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.ff(x)


# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py


class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only=None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b n d"] = None,  # context c  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
        else:
            return self.processor(self, x, mask=mask, rope=rope)


# Attention processor


class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding
    ) -> torch.FloatTensor:
        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None
        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)

        return x

# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py


class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float["b n d"],  # noised input x  # noqa: F722
        c: float["b nt d"] = None,  # context c, here text  # noqa: F722
        mask: bool["b n"] | None = None,  # noqa: F722
        rope=None,  # rotary position embedding for x
        c_rope=None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value=True)  # no mask for c (text)
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)  # 'b n -> b 1 1 n'
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, : residual.shape[1]],
            x[:, residual.shape[1] :],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = mask.unsqueeze(-1)
            x = x.masked_fill(~mask, 0.0)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c

# DiT Block


class DiTBlock(nn.Module):
    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=AttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, t, mask=None, rope=None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x


# MMDiT Block https://arxiv.org/abs/2403.03206


class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: the last layer only does prenorm + modulation, since no further ffn follows
    """

    def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor=JointAttnProcessor(),
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            context_dim=dim,
            context_pre_only=context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")

    def forward(self, x, c, t, mask=None, rope=None, c_rope=None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x


# time step conditioning embedding


class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):  # noqa: F821
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
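
Note: a hypothetical smoke test (not part of this commit) wiring TimestepEmbedding into a single DiTBlock:

# Hypothetical usage sketch, not part of the commit.
import torch

block = DiTBlock(dim=256, heads=4, dim_head=64)
t = TimestepEmbedding(256)(torch.rand(2))   # "b d" conditioning vector
x = torch.randn(2, 50, 256)                 # "b n d" noised input
y = block(x, t)                             # rope=None skips the rotary embedding
print(y.shape)                              # torch.Size([2, 50, 256])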
GPT_SoVITS/feature_extractor/__init__.py
ADDED
@@ -0,0 +1,3 @@
from . import cnhubert

content_module_map = {"cnhubert": cnhubert}
GPT_SoVITS/feature_extractor/cnhubert.py
ADDED
@@ -0,0 +1,46 @@
import logging
import os

import torch
import torch.nn as nn
from transformers import (
    HubertModel,
    Wav2Vec2FeatureExtractor,
)
from transformers import logging as tf_logging

tf_logging.set_verbosity_error()

logging.getLogger("numba").setLevel(logging.WARNING)

cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"


class CNHubert(nn.Module):
    def __init__(self, base_path: str = ""):
        super().__init__()
        if not base_path:
            base_path = cnhubert_base_path
        if not os.path.exists(base_path):
            raise FileNotFoundError(base_path)
        self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)

    def forward(self, x):
        input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
        feats = self.model(input_values)["last_hidden_state"]
        return feats


def get_model():
    model = CNHubert()
    model.eval()
    return model


def get_content(hmodel, wav_16k_tensor):
    with torch.no_grad():
        feats = hmodel(wav_16k_tensor)
    return feats.transpose(1, 2)
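
Note: a hypothetical usage sketch (not part of this commit; it assumes the pretrained chinese-hubert-base weights exist at cnhubert_base_path):

# Hypothetical usage sketch, not part of the commit.
import torch

hmodel = get_model()                   # CNHubert in eval mode
wav_16k = torch.randn(1, 16000)        # 1 second of 16 kHz audio
feats = get_content(hmodel, wav_16k)   # last_hidden_state transposed to (batch, hidden, frames)
print(feats.shape)                     # e.g. torch.Size([1, 768, 49])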
GPT_SoVITS/inference_webui.py
ADDED
@@ -0,0 +1,1104 @@
import argparse
import contextlib
import logging
import os
import re
import shutil
import traceback
import warnings
import zipfile
from functools import partial
from pathlib import Path
from time import time as ttime
from typing import Any

import gradio as gr
import librosa
import numpy as np
import spaces
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from transformers import AutoModelForMaskedLM, AutoTokenizer

from config import (
    change_choices,
    get_dtype,
    get_weights_names,
)
from config import (
    infer_device as default_device,
)
from GPT_SoVITS.Accelerate import PyTorch, T2SEngineProtocol, T2SRequest, backends
from GPT_SoVITS.Accelerate.logger import console
from GPT_SoVITS.feature_extractor import cnhubert
from GPT_SoVITS.module.mel_processing import mel_spectrogram_torch, spectrogram_torch
from GPT_SoVITS.module.models import SynthesizerTrn
from GPT_SoVITS.process_ckpt import inspect_version
from GPT_SoVITS.sv import SV
from GPT_SoVITS.text import cleaned_text_to_sequence
from GPT_SoVITS.text.cleaner import clean_text
from GPT_SoVITS.text.LangSegmenter import LangSegmenter
from tools.assets import css, js, top_html
from tools.i18n.i18n import I18nAuto, scan_language_list
from tools.my_utils import DictToAttrRecursive

warnings.filterwarnings(
    "ignore", message="MPS: The constant padding of more than 3 dimensions is not currently supported natively."
)
warnings.filterwarnings("ignore", message=".*ComplexHalf support is experimental.*")

logging.getLogger("markdown_it").setLevel(logging.ERROR)
logging.getLogger("urllib3").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
logging.getLogger("multipart.multipart").setLevel(logging.ERROR)

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


def install():
    base = Path("GPT_SoVITS")
    zip_path = hf_hub_download("XXXXRT/GPT-SoVITS-Pretrained", "pretrained_models.zip", repo_type="model")
    tmp = base / "tmp_unzip"
    if tmp.exists():
        shutil.rmtree(tmp)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(tmp)
    folder = next(tmp.iterdir())
    shutil.move(str(folder), base / folder.name)
    shutil.rmtree(tmp)


install()


_LANG_RE = re.compile(r"^[a-z]{2}[_-][A-Z]{2}$")


def lang_type(text: str) -> str:
    if text == "Auto":
        return text
    if not _LANG_RE.match(text):
        raise argparse.ArgumentTypeError(f"Unsupported Format: {text}, Expected ll_CC/ll-CC")
    ll, cc = re.split(r"[_-]", text)
    language = f"{ll}_{cc}"
    if language in scan_language_list():
        return language
    else:
        return "Auto"


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        prog="inference_webui",
        description=f"python -s -m GPT_SoVITS.inference_webui zh_CN -b {backends[-1]}",
    )
    p.add_argument(
        "language",
        nargs="?",
        default="Auto",
        type=lang_type,
        help="Language Code, Such as zh_CN, en-US",
    )
    p.add_argument(
        "--backends",
        "-b",
        choices=backends,
        default=backends[-1],
        help="AR Inference Backend",
        required=False,
    )
    p.add_argument(
        "--device",
        "-d",
        default=str(default_device),
        help="Inference Device",
        required=False,
    )
    p.add_argument(
        "--port",
        "-p",
        default=9872,
        type=int,
        help="WebUI Binding Port",
        required=False,
    )
    p.add_argument(
        "--share",
        "-s",
        default=False,
        action="store_true",
        help="Gradio Share Link",
        required=False,
    )
    p.add_argument(
        "--cnhubert",
        default="GPT_SoVITS/pretrained_models/chinese-hubert-base",
        help="CNHuBERT Pretrain",
        required=False,
    )
    p.add_argument(
        "--bert",
        default="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        help="BERT Pretrain",
        required=False,
    )
    p.add_argument(
        "--gpt",
        default="",
        help="GPT Model",
        required=False,
    )
    p.add_argument(
        "--sovits",
        default="",
        help="SoVITS Model",
        required=False,
    )

    return p


args = build_parser().parse_args()

hps: Any = None
vq_model: SynthesizerTrn | None = None
t2s_engine: T2SEngineProtocol | None = None

version = model_version = "v2"
cnhubert_base_path = str(args.cnhubert)
bert_path = str(args.bert)
infer_ttswebui = int(args.port)
is_share = bool(args.share)


i18n = I18nAuto(language=args.language)
ar_backend: str = args.backends
change_choices_i18n = partial(change_choices, i18n=i18n)

SoVITS_names, GPT_names = get_weights_names(i18n)


dict_language_v1 = {
    i18n("中文"): "all_zh",  # recognize everything as Chinese
    i18n("英文"): "en",  # recognize everything as English
    i18n("日文"): "all_ja",  # recognize everything as Japanese
    i18n("中英混合"): "zh",  # recognize as mixed Chinese and English
    i18n("日英混合"): "ja",  # recognize as mixed Japanese and English
    i18n("多语种混合"): "auto",  # multilingual: split, then detect the language of each segment
}
dict_language_v2 = {
    i18n("中文"): "all_zh",  # recognize everything as Chinese
    i18n("英文"): "en",  # recognize everything as English
    i18n("日文"): "all_ja",  # recognize everything as Japanese
    i18n("粤语"): "all_yue",  # recognize everything as Cantonese
    i18n("韩文"): "all_ko",  # recognize everything as Korean
    i18n("中英混合"): "zh",
    i18n("日英混合"): "ja",
    i18n("粤英混合"): "yue",
    i18n("韩英混合"): "ko",
    i18n("多语种混合"): "auto",  # multilingual: split, then detect the language of each segment
    i18n("多语种混合(粤语)"): "auto_yue",  # multilingual with Cantonese: split, then detect per segment
}
dict_language = dict_language_v1 if version == "v1" else dict_language_v2

punctuation = set(["!", "?", "…", ",", ".", "-", " "])
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…"}
v3v4set = {"v3", "v4"}

infer_device = torch.device(args.device)
device = infer_device if infer_device.type == "cuda" else torch.device("cpu")

dtype = get_dtype(device.index)
is_half = dtype == torch.float16

tokenizer = AutoTokenizer.from_pretrained(bert_path)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path).to(infer_device, dtype)

cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model().to(infer_device, dtype)

spec_min = -12
spec_max = 2


def norm_spec(x):
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1


def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
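
Note: norm_spec affinely maps log-mel values from [spec_min, spec_max] = [-12, 2] into [-1, 1], and denorm_spec inverts it. A quick check (hypothetical, not part of this commit):

# Hypothetical sanity check, not part of the commit.
import torch

x = torch.tensor([-12.0, -5.0, 2.0])
print(norm_spec(x))                               # tensor([-1., 0., 1.])
assert torch.allclose(denorm_spec(norm_spec(x)), x)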

def mel_fn(x):
    return mel_spectrogram_torch(
        y=x,
        n_fft=1024,
        num_mels=100,
        sampling_rate=24000,
        hop_size=256,
        win_size=1024,
        fmin=0,
        fmax=None,
        center=False,
    )


def mel_fn_v4(x):
    return mel_spectrogram_torch(
        y=x,
        n_fft=1280,
        num_mels=100,
        sampling_rate=32000,
        hop_size=320,
        win_size=1280,
        fmin=0,
        fmax=None,
        center=False,
    )


gpt_path = str(args.gpt) or GPT_names[0][-1]
sovits_path = str(args.sovits) or SoVITS_names[0][-1]


def get_bert_feature(text, word2ph):
    inputs = tokenizer(text, return_tensors="pt")
    for i in inputs:
        inputs[i] = inputs[i].to(infer_device)
    res = bert_model(**inputs, output_hidden_states=True)
    res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]

    assert len(word2ph) == len(text)
    phone_level_feature = []
    for i in range(len(word2ph)):
        repeat_feature = res[i].repeat(word2ph[i], 1)
        phone_level_feature.append(repeat_feature)
    phone_level_feature_t = torch.cat(phone_level_feature, dim=0)
    return phone_level_feature_t.T
| 284 |
+
|
| 285 |
+
|
| 286 |
+
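The repeat inside get_bert_feature expands character-level BERT features to phone level: character i is copied word2ph[i] times so each phone gets its character's embedding. A minimal sketch with hypothetical shapes (not taken from the file):

res = torch.randn(2, 1024)          # features for 2 characters
word2ph = [1, 2]                    # char 0 -> 1 phone, char 1 -> 2 phones
expanded = torch.cat([res[i].repeat(word2ph[i], 1) for i in range(len(word2ph))], dim=0)
assert expanded.shape == (3, 1024)  # one row per phone; the function returns the transpose, (1024, 3)
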
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
    global vq_model, hps, version, model_version, dict_language
    model_version, version, is_lora, hps, dict_s2 = inspect_version(sovits_path)
    print(sovits_path, version, model_version, is_lora)
    dict_language = dict_language_v1 if version == "v1" else dict_language_v2
    visible_sample_steps = visible_inp_refs = None
    if prompt_language is not None and text_language is not None:
        if prompt_language in list(dict_language.keys()):
            prompt_text_update, prompt_language_update = gr.skip(), gr.update(choices=list(dict_language.keys()))
        else:
            prompt_text_update = gr.update(value="")
            prompt_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))
        if text_language in list(dict_language.keys()):
            text_update, text_language_update = gr.skip(), gr.skip()
        else:
            text_update = gr.update(value="")
            text_language_update = gr.update(value=i18n("中文"), choices=list(dict_language.keys()))

        if model_version in v3v4set:
            visible_sample_steps = True
            visible_inp_refs = False
        else:
            visible_sample_steps = False
            visible_inp_refs = True
        yield (
            prompt_text_update,
            prompt_language_update,
            text_update,
            text_language_update,
            gr.update(
                visible=visible_sample_steps,
                value=32 if model_version == "v3" else 8,
                choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
            ),
            gr.update(visible=visible_inp_refs),
            gr.update(value=False, interactive=True if model_version not in v3v4set else False),
            gr.update(visible=True if model_version == "v3" else False),
            gr.update(value=i18n("模型加载中,请等待"), interactive=False),
        )

    hps = DictToAttrRecursive(hps)
    hps.model.semantic_frame_rate = "25hz"
    hps.model.version = model_version
    if model_version not in v3v4set:
        vq_model = SynthesizerTrn(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            **hps.model,
        )
    else:
        raise RuntimeError("Unsupported model version")

    if "pretrained" not in sovits_path:
        if hasattr(vq_model, "enc_q"):
            del vq_model.enc_q

    if is_lora is False:
        console.print(f">> loading sovits_{model_version}", vq_model.load_state_dict(dict_s2["weight"], strict=False))
    else:
        raise RuntimeError("Unsupported model version")

    vq_model = vq_model.to(infer_device, dtype)

    yield (
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.skip(),
        gr.update(value=i18n("合成语音"), interactive=True),
    )


with contextlib.suppress(UnboundLocalError):
    next(change_sovits_weights(sovits_path))


def change_gpt_weights(gpt_path):
    global t2s_engine, config

    t2s_engine = PyTorch.T2SEngineTorch(
        PyTorch.T2SEngineTorch.load_decoder(Path(gpt_path), backend=ar_backend),
        device,
        dtype=dtype,
    )
    # t2s_engine.decoder_model.compile()
    total = sum(p.numel() for p in t2s_engine.decoder_model.parameters())
    console.print(">> Number of parameters: %.2fM" % (total / 1e6))


change_gpt_weights(gpt_path)


sv_cn_model = SV(infer_device, is_half)


resample_transform_dict = {}


def resample(audio_tensor, sr0, sr1, device):
    global resample_transform_dict
    key = f"{sr0}-{sr1}-{device}"
    if key not in resample_transform_dict:
        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return resample_transform_dict[key](audio_tensor)

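resample memoizes one torchaudio.transforms.Resample module per (source rate, target rate, device) triple, so repeated conversions skip re-building the resampling kernel. A usage sketch with a hypothetical input, assuming the function as defined above:

wav = torch.randn(1, 32000)                            # 1 s of fake audio at 32 kHz
out = resample(wav, 32000, 16000, torch.device("cpu"))
assert out.shape == (1, 16000)                         # a second call with the same key reuses the cached transform
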
def get_spepc(hps, filename, dtype, device, is_v2pro=False):
    sr1 = int(hps.data.sampling_rate)
    audio, sr0 = torchaudio.load_with_torchcodec(filename)
    audio = audio.to(device)

    if sr0 != sr1:
        audio = resample(audio, sr0, sr1, device)
    if audio.shape[0] > 1:
        audio = audio.mean(0).unsqueeze(0)

    maxx = float(audio.abs().max())
    if maxx > 1:
        audio /= min(2, maxx)
    spec = spectrogram_torch(
        audio,
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )
    spec = spec.to(dtype)
    if is_v2pro is True:
        audio = resample(audio, sr1, 16000, device).to(dtype)
    return spec, audio


def clean_text_inf(text, language, version):
    language = language.replace("all_", "")
    phones, word2ph, norm_text = clean_text(text, language, version)
    phones = cleaned_text_to_sequence(phones, version)
    return phones, word2ph, norm_text


def get_bert_inf(phones, word2ph, norm_text, language):
    language = language.replace("all_", "")
    if language == "zh":
        bert = get_bert_feature(norm_text, word2ph).to(device)  # .to(dtype)
    else:
        bert = torch.zeros(
            (1024, len(phones)),
            dtype=torch.float16 if is_half is True else torch.float32,
        ).to(device)

    return bert


def get_first(text):
    pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
    text = re.split(pattern, text)[0].strip()
    return text

def get_phones_and_bert(text, language, version, final=False):
    text = re.sub(r" {2,}", " ", text)
    textlist = []
    langlist = []
    if language == "all_zh":
        for tmp in LangSegmenter.getTexts(text, "zh"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif language == "all_yue":
        for tmp in LangSegmenter.getTexts(text, "zh"):
            if tmp["lang"] == "zh":
                tmp["lang"] = "yue"
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif language == "all_ja":
        for tmp in LangSegmenter.getTexts(text, "ja"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif language == "all_ko":
        for tmp in LangSegmenter.getTexts(text, "ko"):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif language == "en":
        langlist.append("en")
        textlist.append(text)
    elif language == "auto":
        for tmp in LangSegmenter.getTexts(text):
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    elif language == "auto_yue":
        for tmp in LangSegmenter.getTexts(text):
            if tmp["lang"] == "zh":
                tmp["lang"] = "yue"
            langlist.append(tmp["lang"])
            textlist.append(tmp["text"])
    else:
        for tmp in LangSegmenter.getTexts(text):
            if langlist:
                if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
                    textlist[-1] += tmp["text"]
                    continue
            if tmp["lang"] == "en":
                langlist.append(tmp["lang"])
            else:
                # CJK characters cannot be told apart reliably, so trust the user-selected language
                langlist.append(language)
            textlist.append(tmp["text"])
    print(textlist)
    print(langlist)
    phones_list = []
    bert_list = []
    norm_text_list = []
    for i in range(len(textlist)):
        lang = langlist[i]
        phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
        bert = get_bert_inf(phones, word2ph, norm_text, lang)
        phones_list.append(phones)
        norm_text_list.append(norm_text)
        bert_list.append(bert)
    bert = torch.cat(bert_list, dim=1)
    phones = sum(phones_list, [])
    norm_text = "".join(norm_text_list)

    if not final and len(phones) < 6:
        return get_phones_and_bert("." + text, language, version, final=True)

    return phones, bert.to(dtype), norm_text

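To illustrate the fallback branch of get_phones_and_bert with simulated segmenter output (LangSegmenter results vary, so this input is hypothetical): consecutive segments of the same broad class (English vs. non-English) are merged before phoneme conversion, and non-English segments inherit the user-selected language.

segs = [{"lang": "zh", "text": "你好"}, {"lang": "en", "text": "hello "}, {"lang": "en", "text": "world"}]
langlist, textlist = [], []
for tmp in segs:
    if langlist and (tmp["lang"] == "en") == (langlist[-1] == "en"):
        textlist[-1] += tmp["text"]  # same class as the previous segment: merge
        continue
    langlist.append(tmp["lang"] if tmp["lang"] == "en" else "zh")  # "zh" stands in for the user choice
    textlist.append(tmp["text"])
assert langlist == ["zh", "en"] and textlist == ["你好", "hello world"]
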
def merge_short_text_in_array(texts, threshold):
    if len(texts) < 2:
        return texts
    result = []
    text = ""
    for ele in texts:
        text += ele
        if len(text) >= threshold:
            result.append(text)
            text = ""
    if len(text) > 0:
        if len(result) == 0:
            result.append(text)
        else:
            result[len(result) - 1] += text
    return result


sr_model = None


cache: dict[int, Any] = {}

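merge_short_text_in_array greedily concatenates pieces until each chunk reaches the threshold, folding any short remainder into the previous chunk. For example (a sketch using the function above):

assert merge_short_text_in_array(["你好。", "早。", "今天天气不错。"], 5) == ["你好。早。", "今天天气不错。"]
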
@spaces.GPU
def get_tts_wav(
    ref_wav_path,
    prompt_text,
    prompt_language,
    text,
    text_language,
    how_to_cut=i18n("不切"),
    top_k=20,
    top_p=0.6,
    temperature=0.6,
    ref_free=False,
    speed=1,
    if_freeze=False,
    inp_refs=None,
    sample_steps=8,
    if_sr=False,
    pause_second=0.3,
):
    torch.set_grad_enabled(False)
    ttfb_time = ttime()

    if ref_wav_path:
        pass
    else:
        gr.Warning(i18n("请上传参考音频"))
    if text:
        pass
    else:
        gr.Warning(i18n("请填入推理文本"))
    t = []
    if prompt_text is None or len(prompt_text) == 0:
        ref_free = True
    if model_version in v3v4set:
        ref_free = False  # s2v3 does not support ref_free yet
    t0 = ttime()
    prompt_language = dict_language[prompt_language]
    text_language = dict_language[text_language]

    if not ref_free:
        prompt_text = prompt_text.strip("\n")
        if prompt_text[-1] not in splits:
            prompt_text += "。" if prompt_language != "en" else "."
        print(">>", i18n("实际输入的参考文本:"), prompt_text)
    text = text.strip("\n")
    # if (text[0] not in splits and len(get_first(text)) < 4): text = "。" + text if text_language != "en" else "." + text

    print(">>", i18n("实际输入的目标文本:"), text)
    zero_wav = np.zeros(
        int(hps.data.sampling_rate * pause_second),
        dtype=np.float16 if is_half is True else np.float32,
    )
    zero_wav_torch = torch.from_numpy(zero_wav)
    if is_half is True:
        zero_wav_torch = zero_wav_torch.half().to(infer_device)
    else:
        zero_wav_torch = zero_wav_torch.to(infer_device)
    if not ref_free:
        assert vq_model
        wav16k, sr = librosa.load(ref_wav_path, sr=16000)
        if wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000:
            gr.Warning(i18n("参考音频在3~10秒范围外,请更换!"))
            raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
        wav16k_t = torch.from_numpy(wav16k)
        if is_half is True:
            wav16k_t = wav16k_t.half().to(infer_device)
        else:
            wav16k_t = wav16k_t.to(infer_device)
        wav16k_t = torch.cat([wav16k_t, zero_wav_torch])
        ssl_content = ssl_model.model(wav16k_t.unsqueeze(0))["last_hidden_state"].transpose(1, 2)  # .float()
        codes = vq_model.extract_latent(ssl_content)
        prompt_semantic = codes[0, 0]
        prompt = prompt_semantic.unsqueeze(0).to(device)
    else:
        prompt = torch.zeros((1, 0)).to(device, torch.int32)

    t1 = ttime()
    t.append(t1 - t0)

    if how_to_cut == i18n("凑四句一切"):
        text = cut1(text)
    elif how_to_cut == i18n("凑50字一切"):
        text = cut2(text)
    elif how_to_cut == i18n("按中文句号。切"):
        text = cut3(text)
    elif how_to_cut == i18n("按英文句号.切"):
        text = cut4(text)
    elif how_to_cut == i18n("按标点符号切"):
        text = cut5(text)
    while "\n\n" in text:
        text = text.replace("\n\n", "\n")
    texts = text.split("\n")
    texts = process_text(texts)
    texts = merge_short_text_in_array(texts, 5)
    audio_opt = []
    # s2v3 does not support ref_free yet
    if not ref_free:
        phones1, bert1, _ = get_phones_and_bert(prompt_text, prompt_language, version)
    else:
        phones1, bert1 = [], torch.zeros(1024, 0).to(device, dtype)

    infer_len: list[int] = []
    infer_time: list[float] = []
    assert vq_model

    for i_text, text in enumerate(texts):
        # skip blank lines in the target text, which would otherwise raise an error
        if len(text.strip()) == 0:
            continue
        if text[-1] not in splits:
            text += "。" if text_language != "en" else "."
        print(">>", i18n("实际输入的目标文本(每句):"), text)
        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
        print(">>", i18n("前端处理后的文本(每句):"), norm_text2)

        bert = torch.cat([bert1, bert2], 1)
        all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)

        bert = bert.to(device).unsqueeze(0)
        all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)

        t2 = ttime()
        if i_text in cache and if_freeze is True:
            pred_semantic = cache[i_text]
        else:
            t2s_request = T2SRequest(
                [all_phoneme_ids.squeeze(0)],
                all_phoneme_len,
                prompt,
                [bert.squeeze(0)],
                valid_length=1,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                early_stop_num=1500,
                use_cuda_graph=torch.cuda.is_available(),
                # debug=True,
            )
            assert t2s_engine
            t2s_result = t2s_engine.generate(t2s_request)
            if t2s_result.exception is not None:
                console.print(t2s_result.traceback)
                raise RuntimeError()
            pred_semantic_list = t2s_result.result
            assert pred_semantic_list, t2s_result.traceback
            pred_semantic = pred_semantic_list[0].unsqueeze(0).to(infer_device)
            infer_len.append(pred_semantic.shape[-1])
            infer_time.append(t2s_result.infer_speed[-1])

            cache[i_text] = pred_semantic
        t3 = ttime()
        is_v2pro = model_version in {"v2Pro", "v2ProPlus"}

        sv_emb: list[torch.Tensor] = []
        if model_version not in v3v4set:
            refers = []
            if inp_refs:
                for path in inp_refs:
                    try:  # speaker-embedding extraction: either one sv per refer for many refers, or a single sv for a single refer
                        refer, audio_tensor = get_spepc(hps, path.name, dtype, infer_device, is_v2pro)
                        refers.append(refer)
                        if is_v2pro:
                            assert sv_cn_model
                            sv_emb.append(sv_cn_model.compute_embedding(audio_tensor))
                    except Exception as e:
                        print(e)
                        traceback.print_exc()
            if len(refers) == 0:
                refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, infer_device, is_v2pro)
                refers = [refers]
                if is_v2pro:
                    assert sv_cn_model
                    sv_emb = [sv_cn_model.compute_embedding(audio_tensor)]
            if is_v2pro:
                audio = vq_model.decode(
                    pred_semantic,
                    torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
                    refers,
                    speed=speed,
                    sv_emb=sv_emb,
                )[0][0]  # type: ignore
            else:
                audio = vq_model.decode(
                    pred_semantic,
                    torch.LongTensor(phones2).to(infer_device).unsqueeze(0),
                    refers,
                    speed=speed,
                )[0][0]  # type: ignore
        else:
            raise RuntimeError("Unsupported model version")
        if i_text == 0:
            ttfb_time = ttime() - ttfb_time
        max_audio = torch.abs(audio).max()  # simple guard against 16-bit clipping
        if max_audio > 1:
            audio = audio / max_audio
        audio_opt.append(audio)
        audio_opt.append(zero_wav_torch)  # zero_wav
        t4 = ttime()
        t.extend([t2 - t1, t3 - t2, t4 - t3])
        t1 = ttime()

    audio_opt_t = torch.cat(audio_opt, 0)  # np.concatenate
    opt_sr = 32000
    audio_opt_n = audio_opt_t.cpu().numpy()

    t0 = t[0]
    t1 = sum(t[1::3])
    t2 = sum(t[2::3])
    t3 = sum(t[3::3])

    infer_speed_avg = sum(infer_len) / sum(infer_time)
    rtf_value = sum(t) / (len(audio_opt_n) / opt_sr)

    console.print(f">> Time Stamps: {t0:.3f}\t{t1:.3f}\t{t2:.3f}\t{t3:.3f}")
    console.print(f">> Infer Speed: {infer_speed_avg:.2f} Token/s")
    console.print(f">> RTF: {rtf_value:.2f}")
    if ttfb_time > 2:
        console.print(f">> TTFB: {ttfb_time:.3f} s")
    else:
        console.print(f">> TTFB: {ttfb_time * 1000:.3f} ms")

    gr.Info(f"{infer_speed_avg:.2f} Token/s", title="Infer Speed")
    gr.Info(f"{rtf_value:.2f}", title="RTF")

    if ttfb_time > 2:
        gr.Info(f">> TTFB: {ttfb_time:.3f} s")
    else:
        gr.Info(f">> TTFB: {ttfb_time * 1000:.3f} ms")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    yield opt_sr, (audio_opt_n * 32767).astype(np.int16)

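The RTF printed above is compute time divided by output audio duration, so values below 1 mean faster than real time. A worked example with assumed numbers (not measurements from this app):

audio_samples = 8 * 32000                # 8 s of audio at opt_sr = 32000
rtf = 2.0 / (audio_samples / 32000)      # 2 s of compute -> RTF = 0.25, i.e. 4x real time
assert rtf == 0.25
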
def split(todo_text):
    todo_text = todo_text.replace("……", "。").replace("——", ",")
    if todo_text[-1] not in splits:
        todo_text += "。"
    i_split_head = i_split_tail = 0
    len_text = len(todo_text)
    todo_texts = []
    while 1:
        if i_split_head >= len_text:
            break  # the text is guaranteed to end with punctuation, so the last segment was already appended
        if todo_text[i_split_head] in splits:
            i_split_head += 1
            todo_texts.append(todo_text[i_split_tail:i_split_head])
            i_split_tail = i_split_head
        else:
            i_split_head += 1
    return todo_texts


def cut1(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    split_idx: list[int | None] = list(range(0, len(inps) + 1, 4))
    split_idx[-1] = None
    if len(split_idx) > 1:
        opts = []
        for idx in range(len(split_idx) - 1):
            opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
    else:
        opts = [inp]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


def cut2(inp):
    inp = inp.strip("\n")
    inps = split(inp)
    if len(inps) < 2:
        return inp
    opts = []
    summ = 0
    tmp_str = ""
    for i in range(len(inps)):
        summ += len(inps[i])
        tmp_str += inps[i]
        if summ > 50:
            summ = 0
            opts.append(tmp_str)
            tmp_str = ""
    if tmp_str != "":
        opts.append(tmp_str)
    if len(opts) > 1 and len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
        opts[-2] = opts[-2] + opts[-1]
        opts = opts[:-1]
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)

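split walks the text once and emits a segment every time it passes a character in splits, appending a trailing 。 first if the text does not already end with punctuation. For example (a sketch using the function above):

assert split("今天天气不错。我们出去走走吧!") == ["今天天气不错。", "我们出去走走吧!"]
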
def cut3(inp):
    inp = inp.strip("\n")
    opts = inp.strip("。").split("。")
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


def cut4(inp):
    inp = inp.strip("\n")
    opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
    opts = [item for item in opts if not set(item).issubset(punctuation)]
    return "\n".join(opts)


# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
def cut5(inp):
    inp = inp.strip("\n")
    punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
    mergeitems = []
    items = []

    for i, char in enumerate(inp):
        if char in punds:
            if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
                items.append(char)
            else:
                items.append(char)
                mergeitems.append("".join(items))
                items = []
        else:
            items.append(char)

    if items:
        mergeitems.append("".join(items))

    opt = [item for item in mergeitems if not set(item).issubset(punds)]
    return "\n".join(opt)


def process_text(texts):
    _text = []
    if all(text in [None, " ", "\n", ""] for text in texts):
        raise ValueError(i18n("请输入有效文本"))
    for text in texts:
        if text in [None, " ", ""]:
            pass
        else:
            _text.append(text)
    return _text


def html_center(text, label="p"):
    return f"""<div style="text-align: center; margin: 100; padding: 50;">
    <{label} style="margin: 0; padding: 0;">{text}</{label}>
    </div>"""


def html_left(text, label="p"):
    return f"""<div style="text-align: left; margin: 0; padding: 0;">
    <{label} style="margin: 0; padding: 0;">{text}</{label}>
    </div>"""

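cut5 splits on any punctuation in its punds set but deliberately keeps a period that sits between two digits, so decimals survive. For example (a sketch using the function above):

assert cut5("你好,世界。数值3.14保留。") == "你好,\n世界。\n数值3.14保留。"
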
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
    gr.HTML(
        top_html.format(
            i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
            + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
        ),
        elem_classes="markdown",
    )
    gr.Markdown(html_center(i18n("模型切换"), "h3"))
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                GPT_dropdown = gr.Dropdown(
                    label=i18n("GPT模型列表"),
                    choices=GPT_names,
                    value=gpt_path,
                    interactive=True,
                )
                SoVITS_dropdown = gr.Dropdown(
                    label=i18n("SoVITS模型列表"),
                    choices=SoVITS_names,
                    value=sovits_path,
                    interactive=True,
                )
        with gr.Column(scale=1):
            refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary", scale=14)
            refresh_button.click(fn=change_choices_i18n, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
    gr.Markdown(html_center(i18n("*请上传并填写参考信息"), "h3"))
    with gr.Row(equal_height=True):
        with gr.Column(scale=2):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    inp_ref = gr.Audio(
                        label=i18n("请上传3~10秒内参考音频,超过会报错!"),
                        type="filepath",
                        sources="upload",
                        scale=13,
                        editable=False,
                        waveform_options={"show_recording_waveform": False},
                    )
                with gr.Column(scale=1):
                    gr.Markdown(
                        html_center(
                            i18n("使用无参考文本模式时建议使用微调的GPT")
                            + "<br>"
                            + i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
                        )
                    )
                    ref_text_free = gr.Checkbox(
                        label=i18n("开启无参考文本模式"),
                        info=i18n("不填参考文本亦相当于开启") + ", " + i18n("v3暂不支持该模式,使用了会报错。"),
                        value=False,
                        interactive=True if model_version not in v3v4set else False,
                        show_label=True,
                        scale=1,
                    )
                    prompt_language = gr.Dropdown(
                        label="",
                        info=i18n("参考音频的语种"),
                        choices=list(dict_language.keys()),
                        value=i18n("中文"),
                    )
                    prompt_text = gr.Textbox(label="", info=i18n("参考音频的文本"), value="", lines=3, max_lines=3)

        with gr.Column(scale=1):
            inp_refs = (
                gr.File(
                    label=i18n(
                        "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
                    ),
                    file_count="multiple",
                )
                if model_version not in v3v4set
                else gr.File(
                    label=i18n(
                        "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
                    ),
                    file_count="multiple",
                    visible=False,
                )
            )
            sample_steps = (
                gr.Radio(
                    label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
                    value=32 if model_version == "v3" else 8,
                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                    visible=True,
                )
                if model_version in v3v4set
                else gr.Radio(
                    label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                    visible=False,
                    value=32 if model_version == "v3" else 8,
                )
            )
            if_sr_Checkbox = gr.Checkbox(
                label=i18n("v3输出如果觉得闷可以试试开超分"),
                value=False,
                interactive=True,
                show_label=True,
                visible=False if model_version != "v3" else True,
            )
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
|
| 1003 |
+
with gr.Row(equal_height=True):
|
| 1004 |
+
with gr.Column(scale=2):
|
| 1005 |
+
text = gr.Textbox(label=i18n("需要合成的文本"), value="", lines=30, max_lines=40)
|
| 1006 |
+
with gr.Column(scale=1):
|
| 1007 |
+
text_language = gr.Dropdown(
|
| 1008 |
+
label=i18n("需要合成的语种") + i18n(".限制范围越小判别效果越好。"),
|
| 1009 |
+
choices=list(dict_language.keys()),
|
| 1010 |
+
value=i18n("中文"),
|
| 1011 |
+
scale=1,
|
| 1012 |
+
)
|
| 1013 |
+
how_to_cut = gr.Dropdown(
|
| 1014 |
+
label=i18n("怎么切"),
|
| 1015 |
+
choices=[
|
| 1016 |
+
i18n("不切"),
|
| 1017 |
+
i18n("凑四句一切"),
|
| 1018 |
+
i18n("凑50字一切"),
|
| 1019 |
+
i18n("按中文句号。切"),
|
| 1020 |
+
i18n("按英文句号.切"),
|
| 1021 |
+
i18n("按标点符号切"),
|
| 1022 |
+
],
|
| 1023 |
+
value=i18n("凑四句一切"),
|
| 1024 |
+
interactive=True,
|
| 1025 |
+
scale=1,
|
| 1026 |
+
)
|
| 1027 |
+
if_freeze = gr.Checkbox(
|
| 1028 |
+
label=i18n("是否直接对上次合成结果调整语速和音色"),
|
| 1029 |
+
value=False,
|
| 1030 |
+
interactive=True,
|
| 1031 |
+
show_label=True,
|
| 1032 |
+
scale=1,
|
| 1033 |
+
)
|
| 1034 |
+
with gr.Row(equal_height=True):
|
| 1035 |
+
speed = gr.Slider(
|
| 1036 |
+
minimum=0.6, maximum=1.65, step=0.05, label=i18n("语速"), value=1, interactive=True, scale=1
|
| 1037 |
+
)
|
| 1038 |
+
pause_second_slider = gr.Slider(
|
| 1039 |
+
minimum=0.1,
|
| 1040 |
+
maximum=0.5,
|
| 1041 |
+
step=0.01,
|
| 1042 |
+
label=i18n("句间停顿秒数"),
|
| 1043 |
+
value=0.3,
|
| 1044 |
+
interactive=True,
|
| 1045 |
+
scale=1,
|
| 1046 |
+
)
|
| 1047 |
+
gr.Markdown(html_center(i18n("GPT采样参数(不懂就用默认):")))
|
| 1048 |
+
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=15, interactive=True, scale=1)
|
| 1049 |
+
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True, scale=1)
|
| 1050 |
+
temperature = gr.Slider(
|
| 1051 |
+
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True, scale=1
|
| 1052 |
+
)
|
| 1053 |
+
with gr.Row(equal_height=True):
|
| 1054 |
+
with gr.Column(scale=2):
|
| 1055 |
+
inference_button = gr.Button(value=i18n("合成语音"), variant="primary", size="lg")
|
| 1056 |
+
with gr.Column(scale=1):
|
| 1057 |
+
output = gr.Audio(
|
| 1058 |
+
label=i18n("输出的语音"),
|
| 1059 |
+
waveform_options={"show_recording_waveform": False},
|
| 1060 |
+
editable=False,
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
inference_button.click(
|
| 1064 |
+
get_tts_wav,
|
| 1065 |
+
[
|
| 1066 |
+
inp_ref,
|
| 1067 |
+
prompt_text,
|
| 1068 |
+
prompt_language,
|
| 1069 |
+
text,
|
| 1070 |
+
text_language,
|
| 1071 |
+
how_to_cut,
|
| 1072 |
+
top_k,
|
| 1073 |
+
top_p,
|
| 1074 |
+
temperature,
|
| 1075 |
+
ref_text_free,
|
| 1076 |
+
speed,
|
| 1077 |
+
if_freeze,
|
| 1078 |
+
inp_refs,
|
| 1079 |
+
sample_steps,
|
| 1080 |
+
if_sr_Checkbox,
|
| 1081 |
+
pause_second_slider,
|
| 1082 |
+
],
|
| 1083 |
+
[output],
|
| 1084 |
+
)
|
| 1085 |
+
SoVITS_dropdown.change(
|
| 1086 |
+
change_sovits_weights,
|
| 1087 |
+
[SoVITS_dropdown, prompt_language, text_language],
|
| 1088 |
+
[
|
| 1089 |
+
prompt_text,
|
| 1090 |
+
prompt_language,
|
| 1091 |
+
text,
|
| 1092 |
+
text_language,
|
| 1093 |
+
sample_steps,
|
| 1094 |
+
inp_refs,
|
| 1095 |
+
ref_text_free,
|
| 1096 |
+
if_sr_Checkbox,
|
| 1097 |
+
inference_button,
|
| 1098 |
+
],
|
| 1099 |
+
)
|
| 1100 |
+
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
if __name__ == "__main__":
|
| 1104 |
+
app.queue(api_open=False, default_concurrency_limit=1, max_size=1024).launch()
|
GPT_SoVITS/module/attentions.py
ADDED
@@ -0,0 +1,658 @@
import math

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

from . import commons
from .modules import LayerNorm


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))
        if isflow:
            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = kwargs["gin_channels"]

    def forward(self, x, x_mask, g=None):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x

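A minimal shape check for the Encoder above (hypothetical sizes; channel-first layout with a broadcastable (batch, 1, time) mask):

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6)
x = torch.randn(1, 192, 100)      # (batch, channels, time)
x_mask = torch.ones(1, 1, 100)    # 1 = valid frame
assert enc(x, x_mask).shape == (1, 192, 100)
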
class Decoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert t_s == t_t, "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert t_s == t_t, "Local attention is only available for self-attention."
                block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slicing to avoid conditional ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along the column dimension
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add zeros at the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)

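The pad/reshape/slice in _relative_position_to_absolute_position converts per-query relative logits of width 2l-1 into an l x l absolute matrix without any gather ops. The same trick in isolation, with tiny assumed sizes:

b, h, l = 1, 1, 3
rel = torch.randn(b, h, l, 2 * l - 1)
x = F.pad(rel, (0, 1))                                    # (b, h, l, 2l)
x = F.pad(x.view(b, h, l * 2 * l), (0, l - 1))            # flatten, then pad l-1 at the end
out = x.view(b, h, l + 1, 2 * l - 1)[:, :, :l, l - 1 :]   # rows realign; slice off the padding
assert out.shape == (1, 1, 3, 3)
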
class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        if causal:
            self.padding = self._causal_padding
        else:
            self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x


class Depthwise_Separable_Conv1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.depth_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self.point_conv(self.depth_conv(input))

    def weight_norm(self):
        self.depth_conv = weight_norm(self.depth_conv, name="weight")
        self.point_conv = weight_norm(self.point_conv, name="weight")

    def remove_weight_norm(self):
        self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
        self.point_conv = remove_weight_norm(self.point_conv, name="weight")


class Depthwise_Separable_TransposeConv1D(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        output_padding=0,
        bias=True,
        dilation=1,
        padding_mode="zeros",  # TODO: refine this type
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.depth_conv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            output_padding=output_padding,
            padding=padding,
            dilation=dilation,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.point_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self.point_conv(self.depth_conv(input))

    def weight_norm(self):
        self.depth_conv = weight_norm(self.depth_conv, name="weight")
        self.point_conv = weight_norm(self.point_conv, name="weight")

    def remove_weight_norm(self):
        remove_weight_norm(self.depth_conv, name="weight")
        remove_weight_norm(self.point_conv, name="weight")


def weight_norm_modules(module, name="weight", dim=0):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
        module.weight_norm()
        return module
    else:
        return weight_norm(module, name, dim)


def remove_weight_norm_modules(module, name="weight"):
    if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
        module.remove_weight_norm()
    else:
        remove_weight_norm(module, name)

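The point of the two depthwise-separable wrappers above is parameter count: a k-tap depthwise pass plus a 1x1 pointwise mix costs roughly C*k + C*C weights instead of C*C*k for a full convolution. Counting under assumed sizes:

count = lambda m: sum(p.numel() for p in m.parameters())
sep = Depthwise_Separable_Conv1D(256, 256, kernel_size=5, padding=2)
full = nn.Conv1d(256, 256, kernel_size=5, padding=2)
assert count(sep) == 67328 and count(full) == 327936    # ~4.9x fewer parameters here
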
class FFT(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers=1,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        isflow=False,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        if isflow:
            cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
            self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
            self.cond_layer = weight_norm_modules(cond_layer, name="weight")
            self.gin_channels = kwargs["gin_channels"]
        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """
        x: decoder input
        g: optional conditioning (used when isflow is True)
        """
        if g is not None:
            g = self.cond_layer(g)

        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
        x = x * x_mask
        for i in range(self.n_layers):
            if g is not None:
                x = self.cond_pre(x)
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
                x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)
        x = x * x_mask
        return x


class TransformerCouplingLayer(nn.Module):
|
| 598 |
+
def __init__(
|
| 599 |
+
self,
|
| 600 |
+
channels,
|
| 601 |
+
hidden_channels,
|
| 602 |
+
kernel_size,
|
| 603 |
+
n_layers,
|
| 604 |
+
n_heads,
|
| 605 |
+
p_dropout=0,
|
| 606 |
+
filter_channels=0,
|
| 607 |
+
mean_only=False,
|
| 608 |
+
wn_sharing_parameter=None,
|
| 609 |
+
gin_channels=0,
|
| 610 |
+
):
|
| 611 |
+
assert channels % 2 == 0, "channels should be divisible by 2"
|
| 612 |
+
super().__init__()
|
| 613 |
+
self.channels = channels
|
| 614 |
+
self.hidden_channels = hidden_channels
|
| 615 |
+
self.kernel_size = kernel_size
|
| 616 |
+
self.n_layers = n_layers
|
| 617 |
+
self.half_channels = channels // 2
|
| 618 |
+
self.mean_only = mean_only
|
| 619 |
+
|
| 620 |
+
self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
|
| 621 |
+
self.enc = (
|
| 622 |
+
Encoder(
|
| 623 |
+
hidden_channels,
|
| 624 |
+
filter_channels,
|
| 625 |
+
n_heads,
|
| 626 |
+
n_layers,
|
| 627 |
+
kernel_size,
|
| 628 |
+
p_dropout,
|
| 629 |
+
isflow=True,
|
| 630 |
+
gin_channels=gin_channels,
|
| 631 |
+
)
|
| 632 |
+
if wn_sharing_parameter is None
|
| 633 |
+
else wn_sharing_parameter
|
| 634 |
+
)
|
| 635 |
+
self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
|
| 636 |
+
self.post.weight.data.zero_()
|
| 637 |
+
self.post.bias.data.zero_()
|
| 638 |
+
|
| 639 |
+
def forward(self, x, x_mask, g=None, reverse=False):
|
| 640 |
+
x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
|
| 641 |
+
h = self.pre(x0) * x_mask
|
| 642 |
+
h = self.enc(h, x_mask, g=g)
|
| 643 |
+
stats = self.post(h) * x_mask
|
| 644 |
+
if not self.mean_only:
|
| 645 |
+
m, logs = torch.split(stats, [self.half_channels] * 2, 1)
|
| 646 |
+
else:
|
| 647 |
+
m = stats
|
| 648 |
+
logs = torch.zeros_like(m)
|
| 649 |
+
|
| 650 |
+
if not reverse:
|
| 651 |
+
x1 = m + x1 * torch.exp(logs) * x_mask
|
| 652 |
+
x = torch.cat([x0, x1], 1)
|
| 653 |
+
logdet = torch.sum(logs, [1, 2])
|
| 654 |
+
return x, logdet
|
| 655 |
+
else:
|
| 656 |
+
x1 = (x1 - m) * torch.exp(-logs) * x_mask
|
| 657 |
+
x = torch.cat([x0, x1], 1)
|
| 658 |
+
return x
|
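The coupling layer above is the building block of the transformer flow: the untouched half of the channels predicts an affine transform (`m`, `logs`) for the other half, so the forward and reverse passes invert each other exactly. Below is a minimal smoke-test sketch, assuming the repo root is importable and that the `Encoder` defined earlier in this file accepts the arguments `TransformerCouplingLayer` forwards to it; all sizes are illustrative, not values used by GPT-SoVITS.

```python
# Hedged sketch: round-trip through TransformerCouplingLayer (illustrative sizes).
import torch

from GPT_SoVITS.module.attentions import TransformerCouplingLayer

layer = TransformerCouplingLayer(
    channels=192,          # must be even: split into two halves of 96
    hidden_channels=192,
    kernel_size=5,
    n_layers=2,
    n_heads=2,
    p_dropout=0.0,
    filter_channels=768,
)
layer.eval()  # keep the inner Encoder deterministic across the two calls

x = torch.randn(1, 192, 50)    # [batch, channels, frames]
x_mask = torch.ones(1, 1, 50)  # every frame is valid

with torch.no_grad():
    y, logdet = layer(x, x_mask)            # forward: x1 -> m + x1 * exp(logs)
    x_rec = layer(y, x_mask, reverse=True)  # reverse: x1 -> (x1 - m) * exp(-logs)

print(torch.allclose(x, x_rec, atol=1e-5))  # True: x0 is untouched, x1 round-trips
```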
GPT_SoVITS/module/attentions_onnx.py
ADDED
@@ -0,0 +1,385 @@
import logging  # not in the original hunk; added so Encoder's logging.debug call below resolves
import math
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from . import commons


class LayerNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


class Encoder(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        # if isflow:
        #     cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
        #     self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
        #     self.cond_layer = weight_norm(cond_layer, name='weight')
        #     self.gin_channels = 256
        self.cond_layer_idx = self.n_layers
        self.spk_emb_linear = nn.Linear(256, self.hidden_channels)
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                # vits2 says 3rd block, so idx is 2 by default
                self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                # original passed two positional args to logging.debug, which breaks %-formatting
                logging.debug("gin_channels: %s, cond_layer_idx: %s", self.gin_channels, self.cond_layer_idx)
                assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    # def forward(self, x, x_mask, g=None):
    #     attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
    #     x = x * x_mask
    #     for i in range(self.n_layers):
    #         if i == self.cond_layer_idx and g is not None:
    #             g = self.spk_emb_linear(g.transpose(1, 2))
    #             g = g.transpose(1, 2)
    #             x = x + g
    #             x = x * x_mask
    #         y = self.attn_layers[i](x, x, attn_mask)
    #         y = self.drop(y)
    #         x = self.norm_layers_1[i](x + y)

    #         y = self.ffn_layers[i](x, x_mask)
    #         y = self.drop(y)
    #         x = self.norm_layers_2[i](x + y)
    #         x = x * x_mask
    #     return x

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        ):
            y = attn_layers(x, x, attn_mask)
            y = self.drop(y)
            x = norm_layers_1(x + y)

            y = ffn_layers(x, x_mask)
            y = self.drop(y)
            x = norm_layers_2(x + y)
        x = x * x_mask
        return x


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
            self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        # x, self.attn = self.attention(q, k, v, mask=attn_mask)
        x, _ = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, _ = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))

        if self.window_size is not None:
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)

        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)

        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)

        output = output.transpose(2, 3).contiguous().view(b, d, -1)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
        pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
        pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
        slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))

        slice_end_position = slice_start_position + 2 * length - 1
        padded_relative_embeddings = F.pad(
            relative_embeddings,
            commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
        )
        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)


class FFN(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation="",
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # From the call sites, this is always False
        # if causal:
        #     self.padding = self._causal_padding
        # else:
        #     self.padding = self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(self.padding(x * x_mask))
        return x * x_mask

    def padding(self, x):
        return self._same_padding(x)

    def _causal_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = self.kernel_size - 1
        pad_r = 0
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x

    def _same_padding(self, x):
        if self.kernel_size == 1:
            return x
        pad_l = (self.kernel_size - 1) // 2
        pad_r = self.kernel_size // 2
        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
        x = F.pad(x, commons.convert_pad_shape(padding))
        return x


class MRTE(nn.Module):
    def __init__(
        self,
        content_enc_channels=192,
        hidden_size=512,
        out_channels=192,
        kernel_size=5,
        n_heads=4,
        ge_layer=2,
    ):
        super(MRTE, self).__init__()
        self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)
        self.c_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.text_pre = nn.Conv1d(content_enc_channels, hidden_size, 1)
        self.c_post = nn.Conv1d(hidden_size, out_channels, 1)

    def forward(self, ssl_enc, ssl_mask, text, text_mask, ge):
        attn_mask = text_mask.unsqueeze(2) * ssl_mask.unsqueeze(-1)

        ssl_enc = self.c_pre(ssl_enc * ssl_mask)
        text_enc = self.text_pre(text * text_mask)
        x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
        x = self.c_post(x * ssl_mask)
        return x
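When `window_size` is set, `MultiHeadAttention` above adds a local positional bias in the style of Shaw et al.'s relative position representations, via the learned `emb_rel_k`/`emb_rel_v` tables of shape `[1, 2 * window_size + 1, channels // n_heads]`. A small shape-check sketch with assumed sizes (nothing here comes from the repo's own tests), assuming the repo root is on the path:

```python
# Hedged sketch: self-attention call with windowed relative positions (assumed sizes).
import torch

from GPT_SoVITS.module.attentions_onnx import MultiHeadAttention

attn = MultiHeadAttention(channels=192, out_channels=192, n_heads=2, window_size=4)
attn.eval()

x = torch.randn(1, 192, 30)      # [batch, channels, time]
mask = torch.ones(1, 1, 30, 30)  # all positions may attend to each other

with torch.no_grad():
    y = attn(x, x, attn_mask=mask)  # query and key/value share x => self-attention

print(y.shape)               # torch.Size([1, 192, 30])
print(attn.emb_rel_k.shape)  # torch.Size([1, 9, 96]) == [1, 2*4 + 1, 192 // 2]
```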
GPT_SoVITS/module/commons.py
ADDED
@@ -0,0 +1,185 @@
import math
import torch
from torch.nn import functional as F


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


# def convert_pad_shape(pad_shape):
#     l = pad_shape[::-1]
#     pad_shape = [item for sublist in l for item in sublist]
#     return pad_shape


def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)"""
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
    return kl


def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(uniform_samples))


def rand_gumbel_like(x):
    g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
    return g


def slice_segments(x, ids_str, segment_size=4):
    ret = torch.zeros_like(x[:, :, :segment_size])
    for i in range(x.size(0)):
        idx_str = ids_str[i]
        idx_end = idx_str + segment_size
        ret[i] = x[i, :, idx_str:idx_end]
    return ret


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()
    if x_lengths is None:
        x_lengths = t
    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size)
    return ret, ids_str


def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)


def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    b, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape


def shift_1d(x):
    x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
    return x


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = length.max()
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]
    """
    device = duration.device

    b, _, t_y, t_x = mask.shape
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path


def clip_grad_value_(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm


def squeeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    t = (t // n_sqz) * n_sqz
    x = x[:, :, :t]
    x_sqz = x.view(b, c, t // n_sqz, n_sqz)
    x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)

    if x_mask is not None:
        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
    else:
        x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
    return x_sqz * x_mask, x_mask


def unsqueeze(x, x_mask=None, n_sqz=2):
    b, c, t = x.size()

    x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)

    if x_mask is not None:
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    else:
        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
    return x_unsqz * x_mask, x_mask
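Two of the helpers above do most of the alignment work elsewhere in the codebase: `sequence_mask` turns lengths into boolean masks, and `generate_path` expands per-token durations into a monotonic attention path. A worked example with made-up numbers (assuming the repo root is importable):

```python
# Worked example for the masking helpers above (values are illustrative).
import torch

from GPT_SoVITS.module.commons import generate_path, sequence_mask

lengths = torch.tensor([3, 5])  # two sequences in a batch
mask = sequence_mask(lengths)   # [2, 5] boolean mask
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

# generate_path turns per-token durations into a monotonic alignment map.
duration = torch.tensor([[[2, 1, 3]]]).float()  # [b=1, 1, t_x=3], total t_y = 6
attn_mask = torch.ones(1, 1, 6, 3)              # [b, 1, t_y, t_x]
path = generate_path(duration, attn_mask)       # [1, 1, 6, 3]
print(path[0, 0].int())  # each output frame attends to exactly one token:
# frames 0-1 -> token 0, frame 2 -> token 1, frames 3-5 -> token 2
```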
GPT_SoVITS/module/core_vq.py
ADDED
@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype
    max_kmeans_samples = 500
    samples = samples[:max_kmeans_samples, :]
    means = sample_vectors(samples, num_clusters)

    print("kmeans start ... ")
    for _ in tqdm(range(num_iters)):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,  # annotation was `int`; the flag is a boolean
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        out_quantized = []

        n_q = n_q or len(self.layers)

        for i, layer in enumerate(self.layers[:n_q]):
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            if layers and i in layers:
                out_quantized.append(quantized)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses, out_quantized

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        st = st or 0
        for layer in self.layers[st:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[st + i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
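The three classes above form a stack: `EuclideanCodebook` does nearest-neighbour lookup with EMA codebook updates, `VectorQuantization` wraps it with optional projections and the commitment loss, and `ResidualVectorQuantization` chains several quantizers over the residual, per Algorithm 1 of the SoundStream paper cited in the docstring. A hedged usage sketch; the sizes are invented, and `kmeans_init=False` is chosen only so this toy example needs no data-dependent initialization:

```python
# Hedged usage sketch for the RVQ stack above (made-up sizes).
import torch

from GPT_SoVITS.module.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4,
    dim=128,
    codebook_size=256,
    kmeans_init=False,  # skip the data-dependent k-means init for this toy example
)
rvq.eval()

x = torch.randn(2, 128, 50)  # [batch, dim, frames]
with torch.no_grad():
    quantized, indices, losses, _ = rvq(x)  # sum of 4 residual quantization stages
    codes = rvq.encode(x)                   # [n_q, batch, frames] integer codes
    recon = rvq.decode(codes)               # back to [batch, dim, frames]

print(quantized.shape, codes.shape, recon.shape)
# torch.Size([2, 128, 50]) torch.Size([4, 2, 50]) torch.Size([2, 128, 50])
```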
GPT_SoVITS/module/data_utils.py
ADDED
@@ -0,0 +1,1073 @@
import os
import random
import traceback

import torch
import torch.nn.functional as F
import torch.utils.data
from tqdm import tqdm

from GPT_SoVITS.text import cleaned_text_to_sequence
from tools.my_utils import load_audio

from .mel_processing import spec_to_mel_torch, spectrogram_torch

version = os.environ.get("version", None)


# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, version=None, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
        if self.is_v2Pro:
            self.path7 = "%s/7-sv_cn" % exp_dir
            assert os.path.exists(self.path7)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
        names5 = set(os.listdir(self.path5))
        if self.is_v2Pro:
            names6 = set([name[:-3] for name in list(os.listdir(self.path7))])  # strip the .pt suffix
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]
        if self.is_v2Pro:
            self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
        else:
            self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # need at least enough samples for one batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    typee = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
            if self.is_v2Pro:
                sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
        except Exception:
            traceback.print_exc()
            spec = torch.zeros(1025, 100)
            wav = torch.zeros(1, 100 * self.hop_length)
            ssl = torch.zeros(1, 768, 100)
            text = text[-1:]
            if self.is_v2Pro:
                sv_emb = torch.zeros(1, 20480)
            print("load audio or ssl error!!!!!!", audiopath)
        if self.is_v2Pro:
            return (ssl, spec, wav, text, sv_emb)
        else:
            return (ssl, spec, wav, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to [-1, 1]; no extra /32768 needed
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)
        return spec, audio_norm

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)

    def random_slice(self, ssl, wav, mel):
        assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)

        len_mel = mel.shape[1]
        if self.val:
            reference_mel = mel[:, : len_mel // 3]
            return reference_mel, ssl, wav, mel
        dir = random.randint(0, 1)
        sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))

        if dir == 0:
            reference_mel = mel[:, :sep_point]
            ssl = ssl[:, :, sep_point:]
            wav2 = wav[:, sep_point * self.hop_length :]
            mel = mel[:, sep_point:]
        else:
            reference_mel = mel[:, sep_point:]
            ssl = ssl[:, :, :sep_point]
            wav2 = wav[:, : sep_point * self.hop_length]
            mel = mel[:, :sep_point]

        assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
            ssl.shape,
            wav.shape,
            wav2.shape,
            mel.shape,
            sep_point,
            self.hop_length,
            sep_point * self.hop_length,
            dir,
        )
        return reference_mel, ssl, wav2, mel

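The duration gate above never decodes audio: it estimates length from file size alone. A sketch of the arithmetic, assuming the files under 5-wav32k are 16-bit mono PCM (2 bytes per sample); the helper name and the default hop are ours, patterned on the 32 kHz s2 configs elsewhere in this commit:

import os

def approx_duration_and_frames(path, sampling_rate=32000, hop_length=640):
    # Hypothetical helper, not part of the file above.
    size = os.path.getsize(path)            # bytes on disk
    duration_s = size / sampling_rate / 2   # 2 bytes per 16-bit sample
    n_frames = size // (2 * hop_length)     # hop-sized chunks ~ spectrogram frames
    return duration_s, n_frames

Clips outside 0.6 s < duration < 54 s are dropped (except in validation), and n_frames feeds the DistributedBucketSampler at the end of this file.
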

class TextAudioSpeakerCollate:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False, version=None):
        self.return_ids = return_ids
        self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}

    def __call__(self, batch):
        """Collates the training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)

        max_ssl_len = max([x[0].size(2) for x in batch])
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[3].size(0) for x in batch])

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)

        spec_padded.zero_()
        wav_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()

        if self.is_v2Pro:
            sv_embs = torch.FloatTensor(len(batch), 20480)

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]

            ssl = row[0]
            ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, : wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            text = row[3]
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

            if self.is_v2Pro:
                sv_embs[i] = row[4]
        if self.is_v2Pro:
            return (
                ssl_padded,
                ssl_lengths,
                spec_padded,
                spec_lengths,
                wav_padded,
                wav_lengths,
                text_padded,
                text_lengths,
                sv_embs,
            )
        else:
            return (
                ssl_padded,
                ssl_lengths,
                spec_padded,
                spec_lengths,
                wav_padded,
                wav_lengths,
                text_padded,
                text_lengths,
            )

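One detail worth noting in the collate: each time axis is padded not just to the batch maximum but to the next even length, presumably so that later stride-2 layers divide cleanly. The rounding rule always yields an even number strictly greater than the input; a quick check of the exact expression used above:

def round_up_to_even(n):
    # Same expression as in the collate: 10 -> 12, 11 -> 12, 12 -> 14.
    return int(2 * ((n // 2) + 1))

assert [round_up_to_even(n) for n in (10, 11, 12)] == [12, 12, 14]
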

class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
        names5 = set(os.listdir(self.path5))
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # need at least enough samples for one batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths
        self.spec_min = -12
        self.spec_max = 2

        self.filter_length_mel = self.win_length_mel = 1024
        self.hop_length_mel = 256
        self.n_mel_channels = 100
        self.sampling_rate_mel = 24000
        self.mel_fmin = 0
        self.mel_fmax = None

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    typee = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
        except Exception:
            traceback.print_exc()
            mel = torch.zeros(100, 180)
            # wav = torch.zeros(1, 96 * self.hop_length)
            spec = torch.zeros(1025, 96)
            ssl = torch.zeros(1, 768, 96)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, mel, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to [-1, 1]; no extra /32768 needed
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        audio_array24 = load_audio(
            filename, 24000
        )  # load_audio already normalizes to [-1, 1]; no extra /32768 needed. TODO: this resample could be accelerated on GPU.
        audio24 = torch.FloatTensor(audio_array24)  # /32768
        audio_norm24 = audio24
        audio_norm24 = audio_norm24.unsqueeze(0)

        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)

        spec1 = spectrogram_torch(
            audio_norm24,
            self.filter_length_mel,
            self.sampling_rate_mel,
            self.hop_length_mel,
            self.win_length_mel,
            center=False,
        )
        mel = spec_to_mel_torch(
            spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
        )
        mel = torch.squeeze(mel, 0)
        mel = self.norm_spec(mel)
        # print(1111111,spec.shape,mel.shape)
        return spec, mel

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)

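The V3 loader pairs two feature streams on different clocks: the SSL/spec stream at 32 kHz with hop 640 (50 frames/s, assuming the usual s2 settings) and a 24 kHz mel with hop 256 (93.75 frames/s). The 93.75 / 50 = 1.875 = 1.25 * 1.5 ratio is exactly the factor TextAudioSpeakerCollateV3 below uses to size its mel buffer:

# Mel frames/s vs SSL/spec frames/s under the assumed V3 settings.
assert 24000 / 256 == (32000 / 640) * 1.25 * 1.5  # 93.75 == 50 * 1.875
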

class TextAudioSpeakerCollateV3:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates the training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # ssl, spec, wav, mel, text
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
        # (ssl, spec, mel, text)
        max_ssl_len = max([x[0].size(2) for x in batch])

        max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))

        # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1))
        # max_ssl_len1 = max_ssl_len

        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        # max_wav_len = max([x[2].size(1) for x in batch])

        max_text_len = max([x[3].size(0) for x in batch])
        max_mel_len = int(max_ssl_len1 * 1.25 * 1.5)  # 24000/256 vs 32000/640 (= 16000/320)

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))
        # wav_lengths = torch.LongTensor(len(batch))
        mel_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)
        # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)

        spec_padded.zero_()
        mel_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()
        # wav_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            # ssl, spec, wav, mel, text
            ssl = row[0]
            ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            # wav = row[2]
            # wav_padded[i, :, :wav.size(1)] = wav
            # wav_lengths[i] = wav.size(1)

            mel = row[2]
            mel_padded[i, :, : mel.size(1)] = mel
            mel_lengths[i] = mel.size(1)

            text = row[3]
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

        # return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths, mel_lengths
        return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths


class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
        names5 = set(os.listdir(self.path5))
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # need at least enough samples for one batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths
        self.spec_min = -12
        self.spec_max = 2

        self.filter_length_mel = self.win_length_mel = 1280
        self.hop_length_mel = 320
        self.n_mel_channels = 100
        self.sampling_rate_mel = 32000
        self.mel_fmin = 0
        self.mel_fmax = None

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    typee = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
        except Exception:
            traceback.print_exc()
            mel = torch.zeros(100, 192)
            # wav = torch.zeros(1, 96 * self.hop_length)
            spec = torch.zeros(1025, 96)
            ssl = torch.zeros(1, 768, 96)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, mel, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to [-1, 1]; no extra /32768 needed
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)
        spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False)
        mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
        mel = self.norm_spec(torch.squeeze(mel, 0))
        return spec, mel

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)

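V4 simplifies the bookkeeping by keeping both streams on the 32 kHz clock: spec with hop 640 (again assuming the s2 configs) and mel with hop 320, i.e. exactly two mel frames per spec frame. That is why TextAudioSpeakerCollateV4 below can size mel_padded as max_spec_len * 2 rather than re-deriving a rate:

# Two mel frames per spec frame when both streams run at 32 kHz (V4 settings).
spec_hop, mel_hop = 640, 320
assert spec_hop // mel_hop == 2
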

class TextAudioSpeakerCollateV4:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates the training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # ssl, spec, wav, mel, text
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
        # (ssl, spec, mel, text)
        max_ssl_len = max([x[0].size(2) for x in batch])
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        # max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[3].size(0) for x in batch])

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))
        # wav_lengths = torch.LongTensor(len(batch))
        mel_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)
        # wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)

        spec_padded.zero_()
        mel_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()
        # wav_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            # ssl, spec, wav, mel, text
            ssl = row[0]
            ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            # wav = row[2]
            # wav_padded[i, :, :wav.size(1)] = wav
            # wav_lengths[i] = wav.size(1)

            mel = row[2]
            mel_padded[i, :, : mel.size(1)] = mel
            mel_lengths[i] = mel.size(1)

            text = row[3]
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

        # return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths, mel_lengths
        return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths


class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of integers
    3) computes spectrograms from audio files.
    """

    def __init__(self, hparams, val=False):
        exp_dir = hparams.exp_dir
        self.path2 = "%s/2-name2text.txt" % exp_dir
        self.path4 = "%s/4-cnhubert" % exp_dir
        self.path5 = "%s/5-wav32k" % exp_dir
        assert os.path.exists(self.path2)
        assert os.path.exists(self.path4)
        assert os.path.exists(self.path5)
        names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
        names5 = set(os.listdir(self.path5))
        self.phoneme_data = {}
        with open(self.path2, "r", encoding="utf8") as f:
            lines = f.read().strip("\n").split("\n")

        for line in lines:
            tmp = line.split("\t")
            if len(tmp) != 4:
                continue
            self.phoneme_data[tmp[0]] = [tmp[1]]

        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
        tmp = self.audiopaths_sid_text
        leng = len(tmp)
        min_num = 100
        if leng < min_num:
            self.audiopaths_sid_text = []
            for _ in range(max(2, int(min_num / leng))):
                self.audiopaths_sid_text += tmp
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.sampling_rate = hparams.sampling_rate
        self.val = val

        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

        print("phoneme_data_len:", len(self.phoneme_data.keys()))
        print("wav_data_len:", len(self.audiopaths_sid_text))

        audiopaths_sid_text_new = []
        lengths = []
        skipped_phone = 0
        skipped_dur = 0
        for audiopath in tqdm(self.audiopaths_sid_text):
            try:
                phoneme = self.phoneme_data[audiopath][0]
                phoneme = phoneme.split(" ")
                phoneme_ids = cleaned_text_to_sequence(phoneme, version)
            except Exception:
                print(f"{audiopath} not in self.phoneme_data !")
                skipped_phone += 1
                continue

            size = os.path.getsize("%s/%s" % (self.path5, audiopath))
            duration = size / self.sampling_rate / 2

            if duration == 0:
                print(f"Zero duration for {audiopath}, skipping...")
                skipped_dur += 1
                continue

            if 54 > duration > 0.6 or self.val:
                audiopaths_sid_text_new.append([audiopath, phoneme_ids])
                lengths.append(size // (2 * self.hop_length))
            else:
                skipped_dur += 1
                continue

        print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
        print("total left: ", len(audiopaths_sid_text_new))
        assert len(audiopaths_sid_text_new) > 1  # need at least enough samples for one batch; TODO
        self.audiopaths_sid_text = audiopaths_sid_text_new
        self.lengths = lengths
        self.spec_min = -12
        self.spec_max = 2

        self.filter_length_mel = self.win_length_mel = 1024
        self.hop_length_mel = 256
        self.n_mel_channels = 100
        self.sampling_rate_mel = 24000
        self.mel_fmin = 0
        self.mel_fmax = None

    def norm_spec(self, x):
        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1

    def get_audio_text_speaker_pair(self, audiopath_sid_text):
        audiopath, phoneme_ids = audiopath_sid_text
        text = torch.FloatTensor(phoneme_ids)
        try:
            spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
            with torch.no_grad():
                ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
                if ssl.shape[-1] != spec.shape[-1]:
                    typee = ssl.dtype
                    ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
                ssl.requires_grad = False
        except Exception:
            traceback.print_exc()
            mel = torch.zeros(100, 180)
            wav = torch.zeros(1, 96 * self.hop_length)
            spec = torch.zeros(1025, 96)
            ssl = torch.zeros(1, 768, 96)
            text = text[-1:]
            print("load audio or ssl error!!!!!!", audiopath)
        return (ssl, spec, wav, mel, text)

    def get_audio(self, filename):
        audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes to [-1, 1]; no extra /32768 needed
        audio = torch.FloatTensor(audio_array)  # /32768
        audio_norm = audio
        audio_norm = audio_norm.unsqueeze(0)
        audio_array24 = load_audio(
            filename, 24000
        )  # load_audio already normalizes to [-1, 1]; no extra /32768 needed. TODO: this resample could be accelerated on GPU.
        audio24 = torch.FloatTensor(audio_array24)  # /32768
        audio_norm24 = audio24
        audio_norm24 = audio_norm24.unsqueeze(0)

        spec = spectrogram_torch(
            audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
        )
        spec = torch.squeeze(spec, 0)

        spec1 = spectrogram_torch(
            audio_norm24,
            self.filter_length_mel,
            self.sampling_rate_mel,
            self.hop_length_mel,
            self.win_length_mel,
            center=False,
        )
        mel = spec_to_mel_torch(
            spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
        )
        mel = torch.squeeze(mel, 0)
        mel = self.norm_spec(mel)
        # print(1111111,spec.shape,mel.shape)
        return spec, mel, audio_norm

    def get_sid(self, sid):
        sid = torch.LongTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        # with torch.no_grad():
        return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)


class TextAudioSpeakerCollateV3b:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collates the training batch from normalized text, audio and speaker identities
        PARAMS
        ------
        batch: [text_normalized, spec_normalized, wav_normalized, sid]
        """
        # ssl, spec, wav, mel, text
        # Right zero-pad all one-hot text sequences to max input length
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
        # (ssl, spec, mel, text)
        max_ssl_len = max([x[0].size(2) for x in batch])

        max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
        max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))

        # max_ssl_len = int(8 * ((max_ssl_len // 8) + 1))
        # max_ssl_len1 = max_ssl_len

        max_spec_len = max([x[1].size(1) for x in batch])
        max_spec_len = int(2 * ((max_spec_len // 2) + 1))
        max_wav_len = max([x[2].size(1) for x in batch])
        max_text_len = max([x[4].size(0) for x in batch])
        max_mel_len = int(max_ssl_len1 * 1.25 * 1.5)  # 24000/256 vs 32000/640 (= 16000/320)

        ssl_lengths = torch.LongTensor(len(batch))
        spec_lengths = torch.LongTensor(len(batch))
        text_lengths = torch.LongTensor(len(batch))
        wav_lengths = torch.LongTensor(len(batch))
        mel_lengths = torch.LongTensor(len(batch))

        spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
        mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
        ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
        text_padded = torch.LongTensor(len(batch), max_text_len)
        wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)

        spec_padded.zero_()
        mel_padded.zero_()
        ssl_padded.zero_()
        text_padded.zero_()
        wav_padded.zero_()

        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            # ssl, spec, wav, mel, text
            ssl = row[0]
            ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
            ssl_lengths[i] = ssl.size(2)

            spec = row[1]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)

            wav = row[2]
            wav_padded[i, :, : wav.size(1)] = wav
            wav_lengths[i] = wav.size(1)

            mel = row[3]
            mel_padded[i, :, : mel.size(1)] = mel
            mel_lengths[i] = mel.size(1)

            text = row[4]
            text_padded[i, : text.size(0)] = text
            text_lengths[i] = text.size(0)

        return (
            ssl_padded,
            spec_padded,
            mel_padded,
            ssl_lengths,
            spec_lengths,
            text_padded,
            text_lengths,
            wav_padded,
            wav_lengths,
            mel_lengths,
        )
        # return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths


class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <= b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
    """

    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            length = self.lengths[i]
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)

        i = len(buckets) - 1
        while i >= 0:
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)
            i -= 1

        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)

        indices = []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))

        batches = []
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            num_samples_bucket = self.num_samples_per_bucket[i]

            rem = num_samples_bucket - len_bucket
            ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]

            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            for j in range(len(ids_bucket) // self.batch_size):
                batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        return self.num_samples // self.batch_size

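Taken together, a minimal sketch of how these pieces would plug into a DataLoader (hypothetical wiring: hps, the batch size, and the boundary list are placeholders patterned on typical VITS-style setups, not values taken from this commit):

from torch.utils.data import DataLoader

train_dataset = TextAudioSpeakerLoader(hps.data, version=version)
collate_fn = TextAudioSpeakerCollate(version=version)
train_sampler = DistributedBucketSampler(
    train_dataset,
    batch_size=16,
    boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
    num_replicas=1,
    rank=0,
    shuffle=True,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
    batch_sampler=train_sampler,  # the sampler emits whole batches, so no batch_size here
)
for epoch in range(100):
    train_sampler.set_epoch(epoch)  # reseeds the deterministic per-epoch shuffle
    for batch in train_loader:
        ssl, ssl_len, spec, spec_len, wav, wav_len, text, text_len = batch[:8]
        ...

The sampler buckets clips by the frame counts the loaders stored in self.lengths, pads each bucket up to a multiple of the global batch size by repeating indices, and yields index lists one batch at a time, which is why it is passed as batch_sampler rather than combined with batch_size.
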
GPT_SoVITS/module/losses.py
ADDED
@@ -0,0 +1,70 @@
import math

import torch


def feature_loss(fmap_r, fmap_g):
    loss = torch.tensor(0).to(fmap_r[0][0].device)
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            rl = rl.float().detach()
            gl = gl.float()
            loss = torch.mean(torch.abs(rl - gl)) + loss

    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = torch.tensor(0).to(disc_real_outputs[0].device)
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss = r_loss + g_loss + loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = torch.tensor(0).to(disc_outputs[0].device)
    gen_losses = []
    for dg in disc_outputs:
        dg = dg.float()
        l = torch.mean((1 - dg) ** 2)
        gen_losses.append(l)
        loss = l + loss

    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """
    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]
    """
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l


def mle_loss(z, m, logs, logdet, mask):
    l = torch.sum(logs) + 0.5 * torch.sum(
        torch.exp(-2 * logs) * ((z - m) ** 2)
    )  # neg normal likelihood w/o the constant term
    l = l - torch.sum(logdet)  # log jacobian determinant
    l = l / torch.sum(torch.ones_like(z) * mask)  # averaging across batch, channel and time axes
    l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
    return l

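These are the standard LSGAN objectives: discriminator_loss drives real logits toward 1 and fake logits toward 0, generator_loss drives fake logits toward 1, and feature_loss is an L1 match on discriminator feature maps scaled by 2. A quick sanity check on dummy tensors (illustrative only, not model outputs):

import torch

real = [torch.ones(2, 5)]   # logits the discriminator assigns to real audio
fake = [torch.zeros(2, 5)]  # logits it assigns to generated audio
d_loss, r_losses, g_losses = discriminator_loss(real, fake)
print(d_loss.item())        # 0.0 -- this discriminator already separates perfectly
g_loss, _ = generator_loss(fake)
print(g_loss.item())        # 1.0 -- the generator is maximally penalized
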
GPT_SoVITS/module/mel_processing.py
ADDED
@@ -0,0 +1,142 @@
import torch
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    if torch.min(y) < -1.2:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.2:
        print("max value is ", torch.max(y))

    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    # wnsize_dtype_device = str(win_size) + '_' + dtype_device
    key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
    # if wnsize_dtype_device not in hann_window:
    if key not in hann_window:
        # hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
        hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)
    # spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[key],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    spec = spec.abs().pow_(2).add_(1e-8).sqrt_()
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    # fmax_dtype_device = str(fmax) + '_' + dtype_device
    key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
    # if fmax_dtype_device not in mel_basis:
    if key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        # mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
        mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
    # spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = torch.matmul(mel_basis[key], spec)
    spec = spectral_normalize_torch(spec)
    return spec


def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    if torch.min(y) < -1.2:
        print("min value is ", torch.min(y))
    if torch.max(y) > 1.2:
        print("max value is ", torch.max(y))

    global mel_basis, hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    # fmax_dtype_device = str(fmax) + '_' + dtype_device
    fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
        dtype_device,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
    )
    # wnsize_dtype_device = str(win_size) + '_' + dtype_device
    wnsize_dtype_device = fmax_dtype_device
    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )

    spec = spec.abs().pow_(2).add_(1e-8).sqrt_()

    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    spec = spectral_normalize_torch(spec)

    return spec

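For reference, this is the round trip the V3/V3b loaders use earlier in this commit: a linear magnitude spectrogram, then projection onto a cached mel filterbank. A minimal sketch on a dummy waveform, using the 24 kHz / 1024-FFT / hop-256 mel settings from those loaders:

y = torch.rand(1, 24000) * 2 - 1  # one second of fake audio in [-1, 1]
spec = spectrogram_torch(y, 1024, 24000, 256, 1024, center=False)
mel = spec_to_mel_torch(spec, 1024, 100, 24000, 0, None)
print(spec.shape, mel.shape)  # (1, 513, T) linear bins -> (1, 100, T) mel bins

Both functions memoize their Hann window and mel basis in module-level dicts keyed by dtype, device, and the full parameter set, so repeated calls with the same settings skip the librosa filterbank recomputation and device transfer.
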
GPT_SoVITS/module/models.py
ADDED
@@ -0,0 +1,1411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import contextlib
import math
import random

import torch
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import Conv1d, Conv2d, ConvTranspose1d
from torch.nn import functional as F
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm

from GPT_SoVITS.f5_tts.model import DiT
from GPT_SoVITS.text import symbols as symbols_v1
from GPT_SoVITS.text import symbols2 as symbols_v2
from GPT_SoVITS.utils import HParams
from tools.my_utils import _open_file

from . import attentions, commons, modules
from .commons import get_padding, init_weights
from .mrte_model import MRTE
from .quantize import ResidualVectorQuantizer


def set_serialization():
    torch.serialization.add_safe_globals([(HParams, "utils.HParams")])
    torch.serialization._open_file = _open_file


set_serialization()


class StochasticDurationPredictor(nn.Module):
    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed in a future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.flows.append(modules.Flip())

        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
            self.post_flows.append(modules.Flip())

        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
            logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q

            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw


class DurationPredictor(nn.Module):
    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)
        x = self.proj(x * x_mask)
        return x * x_mask


class TextEncoder(nn.Module):
    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        latent_channels=192,
        version="v2",
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.latent_channels = latent_channels
        self.version = version

        self.ssl_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder_ssl = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.encoder_text = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        if self.version == "v1":
            symbols = symbols_v1.symbols
        else:
            symbols = symbols_v2.symbols
        self.text_embedding = nn.Embedding(len(symbols), hidden_channels)

        self.mrte = MRTE()

        self.encoder2 = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers // 2,
            kernel_size,
            p_dropout,
        )

        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, y, y_lengths, text, text_lengths, ge, speed=1, test=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)

        y = self.ssl_proj(y * y_mask) * y_mask

        y = self.encoder_ssl(y * y_mask, y_mask)

        text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)
        if test == 1:
            text[:, :] = 0
        text = self.text_embedding(text).transpose(1, 2)
        text = self.encoder_text(text * text_mask, text_mask)
        y = self.mrte(y, y_mask, text, text_mask, ge)
        y = self.encoder2(y * y_mask, y_mask)
        if speed != 1:
            y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
            y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
        stats = self.proj(y) * y_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return y, m, logs, y_mask


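# NOTE (added): a minimal shape sketch for TextEncoder.forward, assuming a batch of one
# utterance with 100 SSL frames (768-dim cnhubert features) and 20 phoneme ids; the
# tensor names below are illustrative and not part of the original file:
#   y    = torch.randn(1, 768, 100)         # SSL features, (B, 768, T_ssl)
#   text = torch.randint(0, 100, (1, 20))   # phoneme ids, (B, T_text)
#   ge   = torch.randn(1, 512, 1)           # style vector from a MelStyleEncoder
#   out, m, logs, mask = enc(y, torch.LongTensor([100]), text, torch.LongTensor([20]), ge)
# `m` and `logs` are the prior mean / log-std split from self.proj, each (B, out_channels, T).

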
class ResidualCouplingBlock(nn.Module):
    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(
                modules.ResidualCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    dilation_rate,
                    n_layers,
                    gin_channels=gin_channels,
                    mean_only=True,
                )
            )
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x


class PosteriorEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask


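# NOTE (added): the last two lines of PosteriorEncoder.forward are the standard VAE
# reparameterization trick: self.proj emits 2 * out_channels stats per frame, split into
# mean m and log-std logs, and a sample is drawn as
#   z = (m + eps * exp(logs)) * x_mask,  with eps ~ N(0, I)
# so gradients flow through m and logs while the randomness stays isolated in eps.

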
class Encoder(nn.Module):
    def __init__(
        self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_lengths, g=None):
        if g is not None:
            g = g.detach()
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        return stats, x_mask


class WNEncoder(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.norm = modules.LayerNorm(out_channels)

    def forward(self, x, x_lengths, g=None):
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        out = self.proj(x) * x_mask
        out = self.norm(out)
        return out


class Generator(torch.nn.Module):
    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        is_bias=False,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()


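# NOTE (added): the total upsampling factor of Generator is the product of
# upsample_rates, so one latent frame expands to prod(upsample_rates) waveform
# samples. A hedged sketch with illustrative values (not read from this file's
# configs):
#   math.prod([10, 8, 2, 2])  # -> 320 audio samples per input frame

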
class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


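# NOTE (added): DiscriminatorP folds a waveform of length t into a 2-D map of shape
# (b, c, t // period, period), so its (kernel_size, 1) convolutions only mix samples
# that lie exactly `period` apart. Minimal illustration with hypothetical numbers:
#   x = torch.randn(1, 1, 1000); period = 3
#   # 1000 % 3 != 0, so forward reflect-pads to 1002, then views as (1, 1, 334, 3)

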
class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


v2pro_set = {"v2Pro", "v2ProPlus"}


class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, use_spectral_norm=False, version=None):
        super(MultiPeriodDiscriminator, self).__init__()
        if version in v2pro_set:
            periods = [2, 3, 5, 7, 11, 17, 23]
        else:
            periods = [2, 3, 5, 7, 11]

        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]
    """

    def __init__(self, spec_channels, gin_channels=0):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)])

        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

    def forward(self, inputs):
        N = inputs.size(0)
        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        return self.proj(out.squeeze(0)).unsqueeze(-1)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


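# NOTE (added): calculate_channels just iterates the Conv2d output-size formula
#   L_out = (L - kernel_size + 2 * pad) // stride + 1
# once per conv. Worked example for the defaults here (kernel 3, stride 2, pad 1,
# 6 convs) and an assumed spec_channels of 704:
#   704 -> 352 -> 176 -> 88 -> 44 -> 22 -> 11
# so the GRU input size becomes ref_enc_filters[-1] * 11 = 128 * 11 = 1408.

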
class Quantizer_module(torch.nn.Module):
    def __init__(self, n_e, e_dim):
        super(Quantizer_module, self).__init__()
        self.embedding = nn.Embedding(n_e, e_dim)
        self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)

    def forward(self, x):
        d = (
            torch.sum(x**2, 1, keepdim=True)
            + torch.sum(self.embedding.weight**2, 1)
            - 2 * torch.matmul(x, self.embedding.weight.T)
        )
        min_indices = torch.argmin(d, 1)
        z_q = self.embedding(min_indices)
        return z_q, min_indices


class Quantizer(torch.nn.Module):
    def __init__(self, embed_dim=512, n_code_groups=4, n_codes=160):
        super(Quantizer, self).__init__()
        assert embed_dim % n_code_groups == 0
        self.quantizer_modules = nn.ModuleList(
            [Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
        )
        self.n_code_groups = n_code_groups
        self.embed_dim = embed_dim

    def forward(self, xin):
        # B, C, T
        B, C, T = xin.shape
        xin = xin.transpose(1, 2)
        x = xin.reshape(-1, self.embed_dim)
        x = torch.split(x, self.embed_dim // self.n_code_groups, dim=-1)
        min_indices = []
        z_q = []
        for _x, m in zip(x, self.quantizer_modules):
            _z_q, _min_indices = m(_x)
            z_q.append(_z_q)
            min_indices.append(_min_indices)  # B * T,
        z_q = torch.cat(z_q, -1).reshape(xin.shape)
        loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
        z_q = xin + (z_q - xin).detach()
        z_q = z_q.transpose(1, 2)
        codes = torch.stack(min_indices, -1).reshape(B, T, self.n_code_groups)
        return z_q, loss, codes.transpose(1, 2)

    def embed(self, x):
        # idx: N, 4, T
        x = x.transpose(1, 2)
        x = torch.split(x, 1, 2)
        ret = []
        for q, embed in zip(x, self.quantizer_modules):
            q = embed.embedding(q.squeeze(-1))
            ret.append(q)
        ret = torch.cat(ret, -1)
        return ret.transpose(1, 2)  # N, C, T


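# NOTE (added): two details of the quantizer worth spelling out. The distance matrix d
# in Quantizer_module uses the expansion ||x - e||^2 = ||x||^2 + ||e||^2 - 2 * x . e,
# computed in one shot against every codebook entry. And the line
#   z_q = xin + (z_q - xin).detach()
# is the straight-through estimator: the forward pass emits the quantized value while
# the backward pass routes gradients to xin as if quantization were the identity.

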
class CodePredictor(nn.Module):
    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_q=8,
        dims=1024,
        ssl_dim=768,
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
        self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)

        self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)

        self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
        self.n_q = n_q
        self.dims = dims

    def forward(self, x, x_mask, refer, codes, infer=False):
        x = x.detach()
        x = self.vq_proj(x * x_mask) * x_mask
        g = self.ref_enc(refer, x_mask)
        x = x + g
        x = self.encoder(x * x_mask, x_mask)
        x = self.out_proj(x * x_mask) * x_mask
        logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
        target = codes[1:].transpose(0, 1)
        if not infer:
            logits = logits.reshape(-1, self.dims)
            target = target.reshape(-1)
            loss = torch.nn.functional.cross_entropy(logits, target)
            return loss
        else:
            _, top10_preds = torch.topk(logits, 10, dim=-1)
            correct_top10 = torch.any(top10_preds == target.unsqueeze(-1), dim=-1)
            top10_acc = 100 * torch.mean(correct_top10.float()).detach().cpu().item()

            print("Top-10 Accuracy:", top10_acc, "%")

            pred_codes = torch.argmax(logits, dim=-1)
            acc = 100 * torch.mean((pred_codes == target).float()).detach().cpu().item()
            print("Top-1 Accuracy:", acc, "%")

            return pred_codes.transpose(0, 1)


class SynthesizerTrn(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v2",
        **kwargs,
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.version = version

        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            version=version,
        )
        self.dec = Generator(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)

        # self.version = os.environ.get("version", "v1")
        if self.version == "v1":
            self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)
        else:
            self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)

        ssl_dim = 768
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer

        self.is_v2pro = self.version in v2pro_set
        if self.is_v2pro:
            self.sv_emb = nn.Linear(20480, gin_channels)
            self.ge_to512 = nn.Linear(gin_channels, 512)
            self.prelu = nn.PReLU(num_parameters=gin_channels)

    def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
        if self.version == "v1":
            ge = self.ref_enc(y * y_mask, y_mask)
        else:
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
        if self.is_v2pro:
            sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
            ge += sv_emb.unsqueeze(-1)
            ge = self.prelu(ge)
            ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
        with autocast(enabled=False):
            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
            with maybe_no_grad:
                if self.freeze_quantizer:
                    self.ssl_proj.eval()
                    self.quantizer.eval()
                ssl = self.ssl_proj(ssl)
                quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])

        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)

        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=ge)
        return (
            o,
            commit_loss,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )

    def infer(self, ssl, y, y_lengths, text, text_lengths, test=None, noise_scale=0.5):
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
        if self.version == "v1":
            ge = self.ref_enc(y * y_mask, y_mask)
        else:
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)

        ssl = self.ssl_proj(ssl)
        quantized, codes, commit_loss, _ = self.quantizer(ssl, layers=[0])
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, test=test)
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o, y_mask, (z, z_p, m_p, logs_p)

    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
        def get_ge(refer, sv_emb):
            ge = None
            if refer is not None:
                refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
                refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
                if self.version == "v1":
                    ge = self.ref_enc(refer * refer_mask, refer_mask)
                else:
                    ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
                if self.is_v2pro:
                    sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
                    ge += sv_emb.unsqueeze(-1)
                    ge = self.prelu(ge)
            return ge

        if isinstance(refer, list):
            ges = []
            for idx, _refer in enumerate(refer):
                ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
                ges.append(ge)
            ge = torch.stack(ges, 0).mean(0)
        else:
            ge = get_ge(refer, sv_emb)

        y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
        x, m_p, logs_p, y_mask = self.enc_p(
            quantized,
            y_lengths,
            text,
            text_lengths,
            self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge,
            speed,
        )
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale

        z = self.flow(z_p, y_mask, g=ge, reverse=True)

        o = self.dec((z * y_mask)[:, :, :], g=ge)
        return o

    def extract_latent(self, x) -> torch.Tensor:
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)


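# NOTE (added): a minimal, hedged sketch of how SynthesizerTrn.decode is typically
# driven at inference time (variable names are illustrative; the actual wiring lives
# in inference_webui.py):
#   codes = vq_model.extract_latent(ssl_features)   # semantic codes from cnhubert SSL
#   audio = vq_model.decode(pred_semantic, phoneme_ids, refer_spec)
# `refer_spec` is a reference spectrogram (or a list of them, which decode averages
# into one style vector); for v2Pro/v2ProPlus an sv_emb speaker vector is also required.

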
class CFM(torch.nn.Module):
    def __init__(self, in_channels, dit):
        super().__init__()
        self.sigma_min = 1e-6

        self.estimator = dit

        self.in_channels = in_channels

        self.criterion = torch.nn.MSELoss()

        self.use_conditioner_cache = True

    @torch.inference_mode()
    def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
        """Forward diffusion"""
        B, T = mu.size(0), mu.size(1)
        x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype) * temperature
        prompt_len = prompt.size(-1)
        prompt_x = torch.zeros_like(x, dtype=mu.dtype)
        prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
        x[..., :prompt_len] = 0
        mu = mu.transpose(2, 1)
        t = 0
        d = 1 / n_timesteps
        text_cache = None
        text_cfg_cache = None
        dt_cache = None
        d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
        for j in range(n_timesteps):
            t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
            # v_pred = model(x, t_tensor, d_tensor, **extra_args)
            v_pred, text_emb, dt = self.estimator(
                x,
                prompt_x,
                x_lens,
                t_tensor,
                d_tensor,
                mu,
                use_grad_ckpt=False,
                drop_audio_cond=False,
                drop_text=False,
                infer=True,
                text_cache=text_cache,
                dt_cache=dt_cache,
            )
            v_pred = v_pred.transpose(2, 1)
            if self.use_conditioner_cache:
                text_cache = text_emb
                dt_cache = dt
            if inference_cfg_rate > 1e-5:
                neg, text_cfg_emb, _ = self.estimator(
                    x,
                    prompt_x,
                    x_lens,
                    t_tensor,
                    d_tensor,
                    mu,
                    use_grad_ckpt=False,
                    drop_audio_cond=True,
                    drop_text=True,
                    infer=True,
                    text_cache=text_cfg_cache,
                    dt_cache=dt_cache,
                )
                neg = neg.transpose(2, 1)
                if self.use_conditioner_cache:
                    text_cfg_cache = text_cfg_emb
                v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
            x = x + d * v_pred
            t = t + d
            x[:, :, :prompt_len] = 0
        return x

    def forward(self, x1, x_lens, prompt_lens, mu, use_grad_ckpt):
        b, _, t = x1.shape
        t = torch.rand([b], device=mu.device, dtype=x1.dtype)
        x0 = torch.randn_like(x1, device=mu.device)
        vt = x1 - x0
        xt = x0 + t[:, None, None] * vt
        dt = torch.zeros_like(t, device=mu.device)
        prompt = torch.zeros_like(x1)
        for i in range(b):
            prompt[i, :, : prompt_lens[i]] = x1[i, :, : prompt_lens[i]]
            xt[i, :, : prompt_lens[i]] = 0
        gailv = 0.3  # "gailv" = probability (Chinese)  # if ttime()>1736250488 else 0.1
        if random.random() < gailv:
            base = torch.randint(2, 8, (t.shape[0],), device=mu.device)
            d = 1 / torch.pow(2, base)
            d_input = d.clone()
            d_input[d_input < 1e-2] = 0
            # with torch.no_grad():
            v_pred_1 = self.estimator(xt, prompt, x_lens, t, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
            # v_pred_1 = self.diffusion(xt, t, d_input, cond=conditioning).detach()
            x_mid = xt + d[:, None, None] * v_pred_1
            # v_pred_2 = self.diffusion(x_mid, t+d, d_input, cond=conditioning).detach()
            v_pred_2 = self.estimator(x_mid, prompt, x_lens, t + d, d_input, mu, use_grad_ckpt).transpose(2, 1).detach()
            vt = (v_pred_1 + v_pred_2) / 2
            vt = vt.detach()
            dt = 2 * d

        vt_pred = self.estimator(xt, prompt, x_lens, t, dt, mu, use_grad_ckpt).transpose(2, 1)
        loss = 0
        for i in range(b):
            loss += self.criterion(vt_pred[i, :, prompt_lens[i] : x_lens[i]], vt[i, :, prompt_lens[i] : x_lens[i]])
        loss /= b

        return loss


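# NOTE (added): CFM.inference integrates the learned velocity field with n_timesteps
# explicit Euler steps: with step size d = 1 / n_timesteps it repeats
#   x <- x + d * v(x, t);  t <- t + d
# from t = 0 to t = 1. When inference_cfg_rate > 0 it also runs an unconditional pass
# and extrapolates v <- v + (v - v_uncond) * cfg_rate, i.e. classifier-free guidance;
# the text/dt caches merely avoid recomputing conditioner embeddings at every step.

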
def set_no_grad(net_g):
    for name, param in net_g.named_parameters():
        param.requires_grad = False


class SynthesizerTrnV3(nn.Module):
    """
    Synthesizer for Training
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=0,
        gin_channels=0,
        use_sdp=True,
        semantic_frame_rate=None,
        freeze_quantizer=None,
        version="v3",
        **kwargs,
    ):
        super().__init__()
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.version = version

        self.model_dim = 512
        self.use_sdp = use_sdp
        self.enc_p = TextEncoder(
            inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels)

        ssl_dim = 768
        assert semantic_frame_rate in ["25hz", "50hz"]
        self.semantic_frame_rate = semantic_frame_rate
        if semantic_frame_rate == "25hz":
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
        else:
            self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)

        self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
        self.freeze_quantizer = freeze_quantizer
        inter_channels2 = 512
        self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
        self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
        self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
        self.cfm = CFM(
            100,
            DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
        )  # text_dim is the condition feature dim
        if self.freeze_quantizer is True:
            set_no_grad(self.ssl_proj)
            set_no_grad(self.quantizer)
            set_no_grad(self.enc_p)

    def forward(
        self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths, use_grad_ckpt
    ):  # ssl_lengths is not needed for now
        with autocast(enabled=False):
            y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
            with maybe_no_grad:
                if self.freeze_quantizer:
                    self.ssl_proj.eval()
                    self.quantizer.eval()
                    self.enc_p.eval()
                ssl = self.ssl_proj(ssl)
                quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
                quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  # B,C,T
                x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest")  # B,C,T
        fea, y_mask_ = self.wns1(
            fea, mel_lengths, ge
        )  # If the 1-minute fine-tuning works fine, there is no need to manually adjust the learning rate.
        B = ssl.shape[0]
        prompt_len_max = mel_lengths * 2 / 3
        prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
        minn = min(mel.shape[-1], fea.shape[-1])
        mel = mel[:, :, :minn]
        fea = fea[:, :, :minn]
        cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea, use_grad_ckpt)
        return cfm_loss

    @torch.no_grad()
    def decode_encp(self, codes, text, refer, ge=None, speed=1):
        # print(2333333, refer.shape)
        # ge = None
        if ge is None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
        y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
        if speed == 1:
            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4))
        else:
            sizee = int(codes.size(2) * (3.875 if self.version == "v3" else 4) / speed) + 1
        y_lengths1 = torch.LongTensor([sizee]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  # B,C,T
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=(1.875 if self.version == "v3" else 2), mode="nearest")  # B,C,T
        # more WN parameters to learn the mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
        return fea, ge

    def extract_latent(self, x):
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)


| 1254 |
+
class SynthesizerTrnV3b(nn.Module):
|
| 1255 |
+
"""
|
| 1256 |
+
Synthesizer for Training
|
| 1257 |
+
"""
|
| 1258 |
+
|
| 1259 |
+
def __init__(
|
| 1260 |
+
self,
|
| 1261 |
+
spec_channels,
|
| 1262 |
+
segment_size,
|
| 1263 |
+
inter_channels,
|
| 1264 |
+
hidden_channels,
|
| 1265 |
+
filter_channels,
|
| 1266 |
+
n_heads,
|
| 1267 |
+
n_layers,
|
| 1268 |
+
kernel_size,
|
| 1269 |
+
p_dropout,
|
| 1270 |
+
resblock,
|
| 1271 |
+
resblock_kernel_sizes,
|
| 1272 |
+
resblock_dilation_sizes,
|
| 1273 |
+
upsample_rates,
|
| 1274 |
+
upsample_initial_channel,
|
| 1275 |
+
upsample_kernel_sizes,
|
| 1276 |
+
n_speakers=0,
|
| 1277 |
+
gin_channels=0,
|
| 1278 |
+
use_sdp=True,
|
| 1279 |
+
semantic_frame_rate=None,
|
| 1280 |
+
freeze_quantizer=None,
|
| 1281 |
+
**kwargs,
|
| 1282 |
+
):
|
| 1283 |
+
super().__init__()
|
| 1284 |
+
self.spec_channels = spec_channels
|
| 1285 |
+
self.inter_channels = inter_channels
|
| 1286 |
+
self.hidden_channels = hidden_channels
|
| 1287 |
+
self.filter_channels = filter_channels
|
| 1288 |
+
self.n_heads = n_heads
|
| 1289 |
+
self.n_layers = n_layers
|
| 1290 |
+
self.kernel_size = kernel_size
|
| 1291 |
+
self.p_dropout = p_dropout
|
| 1292 |
+
self.resblock = resblock
|
| 1293 |
+
self.resblock_kernel_sizes = resblock_kernel_sizes
|
| 1294 |
+
self.resblock_dilation_sizes = resblock_dilation_sizes
|
| 1295 |
+
self.upsample_rates = upsample_rates
|
| 1296 |
+
self.upsample_initial_channel = upsample_initial_channel
|
| 1297 |
+
self.upsample_kernel_sizes = upsample_kernel_sizes
|
| 1298 |
+
self.segment_size = segment_size
|
| 1299 |
+
self.n_speakers = n_speakers
|
| 1300 |
+
self.gin_channels = gin_channels
|
| 1301 |
+
|
| 1302 |
+
self.model_dim = 512
|
| 1303 |
+
self.use_sdp = use_sdp
|
| 1304 |
+
self.enc_p = TextEncoder(
|
| 1305 |
+
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
| 1306 |
+
)
|
| 1307 |
+
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
| 1308 |
+
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
| 1309 |
+
self.dec = Generator(
|
| 1310 |
+
inter_channels,
|
| 1311 |
+
resblock,
|
| 1312 |
+
resblock_kernel_sizes,
|
| 1313 |
+
resblock_dilation_sizes,
|
| 1314 |
+
upsample_rates,
|
| 1315 |
+
upsample_initial_channel,
|
| 1316 |
+
upsample_kernel_sizes,
|
| 1317 |
+
gin_channels=gin_channels,
|
| 1318 |
+
)
|
| 1319 |
+
self.enc_q = PosteriorEncoder(
|
| 1320 |
+
spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels
|
| 1321 |
+
)
|
| 1322 |
+
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
| 1323 |
+
|
| 1324 |
+
ssl_dim = 768
|
| 1325 |
+
assert semantic_frame_rate in ["25hz", "50hz"]
|
| 1326 |
+
self.semantic_frame_rate = semantic_frame_rate
|
| 1327 |
+
if semantic_frame_rate == "25hz":
|
| 1328 |
+
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
| 1329 |
+
else:
|
| 1330 |
+
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
| 1331 |
+
|
| 1332 |
+
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
| 1333 |
+
self.freeze_quantizer = freeze_quantizer
|
| 1334 |
+
|
| 1335 |
+
inter_channels2 = 512
|
| 1336 |
+
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
| 1337 |
+
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
| 1338 |
+
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
| 1339 |
+
self.cfm = CFM(
|
| 1340 |
+
100,
|
| 1341 |
+
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
|
| 1342 |
+
) # text_dim is condition feature dim
|
| 1343 |
+
|
| 1344 |
+
    def forward(self, ssl, y, mel, ssl_lengths, y_lengths, text, text_lengths, mel_lengths):  # ssl_lengths is unused for now
        with autocast(enabled=False):
            y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
            ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
            # ge = self.ref_enc(y * y_mask, y_mask)  # change back: the new spec setting is the whole 24k band
            # ge = None
            maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
            with maybe_no_grad:
                if self.freeze_quantizer:
                    self.ssl_proj.eval()
                    self.quantizer.eval()
                ssl = self.ssl_proj(ssl)
                quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  # BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
        z_p = self.flow(z, y_mask, g=ge)
        z_slice, ids_slice = commons.rand_slice_segments(z, y_lengths, self.segment_size)
        o = self.dec(z_slice, g=ge)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  # BCT
        fea, y_mask_ = self.wns1(fea, mel_lengths, ge)
        learned_mel = self.linear_mel(fea)
        B = ssl.shape[0]
        prompt_len_max = mel_lengths * 2 / 3
        prompt_len = (torch.rand([B], device=fea.device) * prompt_len_max).floor().to(dtype=torch.long)
        minn = min(mel.shape[-1], fea.shape[-1])
        mel = mel[:, :, :minn]
        fea = fea[:, :, :minn]
        cfm_loss = self.cfm(mel, mel_lengths, prompt_len, fea)  # fea == cond; y_lengths == target_mel_lengths; ge is not needed
        return (
            commit_loss,
            cfm_loss,
            F.mse_loss(learned_mel, mel),
            o,
            ids_slice,
            y_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            quantized,
        )
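The CFM loss above conditions on a randomly sized audio prompt: each batch item keeps up to two-thirds of its mel frames as a prefix prompt and flow matching is trained on the remainder. A self-contained sketch of that draw (lengths invented for illustration):

import torch

mel_lengths = torch.tensor([300, 240])        # mel frames per batch item (made up)
prompt_len_max = mel_lengths * 2 / 3          # tensor([200., 160.])
prompt_len = (torch.rand(2) * prompt_len_max).floor().long()
print(prompt_len)                             # e.g. tensor([137, 52]); always below 2/3 of the target length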
    @torch.no_grad()
    def decode_encp(self, codes, text, refer, ge=None):
        # print(2333333, refer.shape)
        # ge = None
        if ge is None:
            refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
            refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
            ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
        y_lengths = torch.LongTensor([int(codes.size(2) * 2)]).to(codes.device)
        y_lengths1 = torch.LongTensor([int(codes.size(2) * 2.5 * 1.5)]).to(codes.device)
        text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)

        quantized = self.quantizer.decode(codes)
        if self.semantic_frame_rate == "25hz":
            quantized = F.interpolate(quantized, scale_factor=2, mode="nearest")  # BCT
        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
        fea = self.bridge(x)
        fea = F.interpolate(fea, scale_factor=1.875, mode="nearest")  # BCT
        # more WN parameters to learn the mel
        fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
        return fea, ge
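At inference time, decode_encp turns semantic codes plus phoneme ids into the CFM conditioning features, deriving the style vector ge from the reference spectrogram when one is not supplied. A hypothetical wrapper sketching the call (the function name and shape notes are mine, not this repository's API):

def codes_to_cfm_cond(model, codes, text, refer_spec):
    """Illustrative only. Assumed shapes: codes [1, n_q, T] (long),
    text [1, L] phoneme ids (long), refer_spec [1, C, T_ref] with at
    least 704 frequency bins (only the first 704 are used)."""
    fea, ge = model.decode_encp(codes, text, refer_spec)
    return fea, ge  # fea conditions the CFM sampler that produces the mel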
    def extract_latent(self, x) -> torch.Tensor:
        ssl = self.ssl_proj(x)
        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
        return codes.transpose(0, 1)
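extract_latent is the entry point for tokenizing SSL features. A shape sketch under the assumptions visible above (ssl_dim=768, n_q=1; the HuBERT-style input and the [n_q, B, T'] output layout are inferred from the transpose, not verified against a running checkpoint):

import torch

ssl_feats = torch.randn(1, 768, 200)       # [B, ssl_dim, T], e.g. cnhubert output
# codes = model.extract_latent(ssl_feats)  # expected: [n_q, B, T'] after transpose(0, 1);
#                                          # with the "25hz" stride-2 projection, T' == T // 2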
|