Spaces:
Runtime error
Runtime error
Merge branch 'main' of github.com:descriptinc/lyrebird-vampnet into main
Browse files
- requirements.txt +1 -1
- scripts/exp/train.py +5 -1
- setup.py +3 -1
- vampnet/modules/base.py +2 -2
requirements.txt
CHANGED
|
@@ -2,12 +2,12 @@ argbind>=0.3.1
|
|
| 2 |
pytorch-ignite
|
| 3 |
rich
|
| 4 |
audiotools @ git+https://github.com/descriptinc/[email protected]
|
|
|
|
| 5 |
tqdm
|
| 6 |
tensorboard
|
| 7 |
google-cloud-logging==2.2.0
|
| 8 |
pytest
|
| 9 |
pytest-cov
|
| 10 |
-
papaya_client @ git+https://github.com/descriptinc/lyrebird-papaya.git@master
|
| 11 |
pynvml
|
| 12 |
psutil
|
| 13 |
pandas
|
|
|
|
| 2 |
pytorch-ignite
|
| 3 |
rich
|
| 4 |
audiotools @ git+https://github.com/descriptinc/[email protected]
|
| 5 |
+
lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main
|
| 6 |
tqdm
|
| 7 |
tensorboard
|
| 8 |
google-cloud-logging==2.2.0
|
| 9 |
pytest
|
| 10 |
pytest-cov
|
|
|
|
| 11 |
pynvml
|
| 12 |
psutil
|
| 13 |
pandas
|
scripts/exp/train.py
CHANGED
|
@@ -59,7 +59,7 @@ IGNORE_INDEX = -100
|
|
| 59 |
@argbind.bind("train", "val", without_prefix=True)
|
| 60 |
def build_transform():
|
| 61 |
transform = transforms.Compose(
|
| 62 |
-
tfm.VolumeNorm(("uniform", -32, -
|
| 63 |
tfm.VolumeChange(("uniform", -6, 3)),
|
| 64 |
tfm.RescaleAudio(),
|
| 65 |
)
|
|
@@ -250,6 +250,7 @@ def train(
|
|
| 250 |
max_epochs: int = int(100e3),
|
| 251 |
epoch_length: int = 1000,
|
| 252 |
save_audio_epochs: int = 10,
|
|
|
|
| 253 |
batch_size: int = 48,
|
| 254 |
grad_acc_steps: int = 1,
|
| 255 |
val_idx: list = [0, 1, 2, 3, 4],
|
|
@@ -506,6 +507,9 @@ def train(
|
|
| 506 |
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
| 507 |
self.print(f"Saving to {str(Path('.').absolute())}")
|
| 508 |
|
|
|
|
|
|
|
|
|
|
| 509 |
if self.is_best(engine, loss_key):
|
| 510 |
self.print(f"Best model so far")
|
| 511 |
tags.append("best")
|
|
|
|
| 59 |
@argbind.bind("train", "val", without_prefix=True)
|
| 60 |
def build_transform():
|
| 61 |
transform = transforms.Compose(
|
| 62 |
+
tfm.VolumeNorm(("uniform", -32, -20)),
|
| 63 |
tfm.VolumeChange(("uniform", -6, 3)),
|
| 64 |
tfm.RescaleAudio(),
|
| 65 |
)
|
|
|
|
| 250 |
max_epochs: int = int(100e3),
|
| 251 |
epoch_length: int = 1000,
|
| 252 |
save_audio_epochs: int = 10,
|
| 253 |
+
save_epochs: list = [10, 50, 100, 200, 300, 400,],
|
| 254 |
batch_size: int = 48,
|
| 255 |
grad_acc_steps: int = 1,
|
| 256 |
val_idx: list = [0, 1, 2, 3, 4],
|
|
|
|
| 507 |
loss_key = "loss/val" if "loss/val" in metadata["logs"] else "loss/train"
|
| 508 |
self.print(f"Saving to {str(Path('.').absolute())}")
|
| 509 |
|
| 510 |
+
if self.state.epoch in save_epochs:
|
| 511 |
+
tags.append(f"epoch={self.state.epoch}")
|
| 512 |
+
|
| 513 |
if self.is_best(engine, loss_key):
|
| 514 |
self.print(f"Best model so far")
|
| 515 |
tags.append("best")
|
setup.py
CHANGED
|
@@ -30,11 +30,13 @@ setup(
|
|
| 30 |
"argbind>=0.3.2",
|
| 31 |
"pytorch-ignite",
|
| 32 |
"rich",
|
| 33 |
-
"audiotools @ git+https://github.com/descriptinc/[email protected].
|
|
|
|
| 34 |
"tqdm",
|
| 35 |
"tensorboard",
|
| 36 |
"google-cloud-logging==2.2.0",
|
| 37 |
"torchmetrics>=0.7.3",
|
| 38 |
"einops",
|
|
|
|
| 39 |
],
|
| 40 |
)
|
|
|
|
| 30 |
"argbind>=0.3.2",
|
| 31 |
"pytorch-ignite",
|
| 32 |
"rich",
|
| 33 |
+
"audiotools @ git+https://github.com/descriptinc/[email protected].3",
|
| 34 |
+
"lac @ git+https://github.com/descriptinc/lyrebird-audio-codec.git@main",
|
| 35 |
"tqdm",
|
| 36 |
"tensorboard",
|
| 37 |
"google-cloud-logging==2.2.0",
|
| 38 |
"torchmetrics>=0.7.3",
|
| 39 |
"einops",
|
| 40 |
+
"flash-attn",
|
| 41 |
],
|
| 42 |
)
|
vampnet/modules/base.py
CHANGED
|
@@ -153,7 +153,7 @@ class VampBase(at.ml.BaseModel):
|
|
| 153 |
sampling_steps: int = 12,
|
| 154 |
start_tokens: Optional[torch.Tensor] = None,
|
| 155 |
mask: Optional[torch.Tensor] = None,
|
| 156 |
-
temperature: Union[float, Tuple[float, float]] =
|
| 157 |
top_k: int = None,
|
| 158 |
sample: str = "gumbel",
|
| 159 |
renoise_mode: str = "start",
|
|
@@ -262,7 +262,7 @@ class VampBase(at.ml.BaseModel):
|
|
| 262 |
sampling_steps: int = 24,
|
| 263 |
start_tokens: Optional[torch.Tensor] = None,
|
| 264 |
mask: Optional[torch.Tensor] = None,
|
| 265 |
-
temperature: Union[float, Tuple[float, float]] =
|
| 266 |
top_k: int = None,
|
| 267 |
sample: str = "multinomial",
|
| 268 |
typical_filtering=False,
|
|
|
|
| 153 |
sampling_steps: int = 12,
|
| 154 |
start_tokens: Optional[torch.Tensor] = None,
|
| 155 |
mask: Optional[torch.Tensor] = None,
|
| 156 |
+
temperature: Union[float, Tuple[float, float]] = 0.8,
|
| 157 |
top_k: int = None,
|
| 158 |
sample: str = "gumbel",
|
| 159 |
renoise_mode: str = "start",
|
|
|
|
| 262 |
sampling_steps: int = 24,
|
| 263 |
start_tokens: Optional[torch.Tensor] = None,
|
| 264 |
mask: Optional[torch.Tensor] = None,
|
| 265 |
+
temperature: Union[float, Tuple[float, float]] = 0.8,
|
| 266 |
top_k: int = None,
|
| 267 |
sample: str = "multinomial",
|
| 268 |
typical_filtering=False,
|