Spaces:
Runtime error
Runtime error
Hugo Flores Garcia committed on
Commit ·
31b771c
1
Parent(s): a66dc9c
dropping torch.compile for now
Browse files
- scripts/exp/train.py +10 -4
- scripts/utils/split_long_audio_file.py +34 -0
scripts/exp/train.py
CHANGED
|
@@ -29,6 +29,9 @@ from audiotools.ml.decorators import (
|
|
| 29 |
|
| 30 |
import loralib as lora
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
# Enable cudnn autotuner to speed up training
|
| 34 |
# (can be altered by the funcs.seed function)
|
|
@@ -510,14 +513,14 @@ def load(
|
|
| 510 |
|
| 511 |
if args["fine_tune"]:
|
| 512 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 513 |
-
model =
|
| 514 |
VampNet.load(location=Path(fine_tune_checkpoint),
|
| 515 |
map_location="cpu",
|
| 516 |
)
|
| 517 |
)
|
| 518 |
|
| 519 |
|
| 520 |
-
model =
|
| 521 |
model = accel.prepare_model(model)
|
| 522 |
|
| 523 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
@@ -601,7 +604,7 @@ def train(
|
|
| 601 |
accel=accel,
|
| 602 |
tracker=tracker,
|
| 603 |
save_path=save_path)
|
| 604 |
-
|
| 605 |
|
| 606 |
train_dataloader = accel.prepare_dataloader(
|
| 607 |
state.train_data,
|
|
@@ -616,13 +619,15 @@ def train(
|
|
| 616 |
num_workers=num_workers,
|
| 617 |
batch_size=batch_size,
|
| 618 |
collate_fn=state.val_data.collate,
|
| 619 |
-
persistent_workers=
|
| 620 |
)
|
|
|
|
| 621 |
|
| 622 |
|
| 623 |
|
| 624 |
if fine_tune:
|
| 625 |
lora.mark_only_lora_as_trainable(state.model)
|
|
|
|
| 626 |
|
| 627 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
| 628 |
# and only run when specific conditions are met.
|
|
@@ -637,6 +642,7 @@ def train(
|
|
| 637 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
| 638 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
| 639 |
|
|
|
|
| 640 |
with tracker.live:
|
| 641 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
| 642 |
train_loop(state, batch, accel)
|
|
|
|
| 29 |
|
| 30 |
import loralib as lora
|
| 31 |
|
| 32 |
+
import torch._dynamo
|
| 33 |
+
torch._dynamo.config.verbose=True
|
| 34 |
+
|
| 35 |
|
| 36 |
# Enable cudnn autotuner to speed up training
|
| 37 |
# (can be altered by the funcs.seed function)
|
|
|
|
| 513 |
|
| 514 |
if args["fine_tune"]:
|
| 515 |
assert fine_tune_checkpoint is not None, "Must provide a fine-tune checkpoint"
|
| 516 |
+
model = (
|
| 517 |
VampNet.load(location=Path(fine_tune_checkpoint),
|
| 518 |
map_location="cpu",
|
| 519 |
)
|
| 520 |
)
|
| 521 |
|
| 522 |
|
| 523 |
+
model = VampNet() if model is None else model
|
| 524 |
model = accel.prepare_model(model)
|
| 525 |
|
| 526 |
# assert accel.unwrap(model).n_codebooks == codec.quantizer.n_codebooks
|
|
|
|
| 604 |
accel=accel,
|
| 605 |
tracker=tracker,
|
| 606 |
save_path=save_path)
|
| 607 |
+
print("initialized state.")
|
| 608 |
|
| 609 |
train_dataloader = accel.prepare_dataloader(
|
| 610 |
state.train_data,
|
|
|
|
| 619 |
num_workers=num_workers,
|
| 620 |
batch_size=batch_size,
|
| 621 |
collate_fn=state.val_data.collate,
|
| 622 |
+
persistent_workers=num_workers > 0,
|
| 623 |
)
|
| 624 |
+
print("initialized dataloader.")
|
| 625 |
|
| 626 |
|
| 627 |
|
| 628 |
if fine_tune:
|
| 629 |
lora.mark_only_lora_as_trainable(state.model)
|
| 630 |
+
print("marked only lora as trainable.")
|
| 631 |
|
| 632 |
# Wrap the functions so that they neatly track in TensorBoard + progress bars
|
| 633 |
# and only run when specific conditions are met.
|
|
|
|
| 642 |
save_samples = when(lambda: accel.local_rank == 0)(save_samples)
|
| 643 |
checkpoint = when(lambda: accel.local_rank == 0)(checkpoint)
|
| 644 |
|
| 645 |
+
print("starting training loop.")
|
| 646 |
with tracker.live:
|
| 647 |
for tracker.step, batch in enumerate(train_dataloader, start=tracker.step):
|
| 648 |
train_loop(state, batch, accel)
|
scripts/utils/split_long_audio_file.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import argbind
|
| 3 |
+
|
| 4 |
+
import audiotools as at
|
| 5 |
+
import tqdm
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@argbind.bind(without_prefix=True)
def split_long_audio_file(
    file: str = None,
    max_chunk_size_s: int = 60*10
):
    """Split one audio file into fixed-length, half-overlapping chunks.

    Chunks are written as ``0.wav``, ``1.wav``, ... into a directory named
    after the input file's stem, located next to the input file.

    Args:
        file: Path of the audio file to split. Required; the ``None``
            default only exists so the argument can be bound from the
            command line.
        max_chunk_size_s: Duration of each chunk in seconds. Consecutive
            chunks start ``max_chunk_size_s / 2`` seconds apart, so
            neighbouring chunks overlap by half.

    Returns:
        The ``Path`` of the directory the chunks were written to.

    Raises:
        TypeError: If ``file`` is not provided.
    """
    if file is None:
        # Path(None) would raise an opaque TypeError; keep the exception
        # type but say what is actually missing.
        raise TypeError("split_long_audio_file requires a `file` path, got None")

    file = Path(file)
    output_dir = file.parent / file.stem
    # exist_ok=True so re-running on the same input does not crash.
    # NOTE: chunks left over from an earlier run are overwritten by index,
    # but higher-numbered stale files are not removed.
    output_dir.mkdir(exist_ok=True)

    signal = at.AudioSignal(file)

    # split into chunks
    windows = signal.windows(
        window_duration=max_chunk_size_s,
        hop_duration=max_chunk_size_s / 2,
        preprocess=True,
    )
    num_written = 0
    for i, chunk in tqdm.tqdm(enumerate(windows)):
        chunk.write(output_dir / f"{i}.wav")
        num_written += 1

    # Count what this call wrote rather than globbing the directory, which
    # may also hold files from a previous run.
    print(f"wrote {num_written} files to {output_dir}")

    return output_dir
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
    # Parse the command-line arguments and make them visible to the
    # argbind-bound function for the duration of the call.
    with argbind.scope(argbind.parse_args()):
        split_long_audio_file()
|