# Base configuration for training a model

paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints

hydra:
  run:
    dir: ${paths.run_dir}

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    process_group_backend: nccl  # NOTE: override this (e.g. with gloo) when training on Windows

  precision: bf16-mixed

  # Disable epoch-based validation; validate every `val_check_interval` steps instead
  check_val_every_n_epoch: null
  val_check_interval: 5000
  max_steps: 100_000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    filename: "step_{step:09d}"
    save_last: false  # do not additionally save a `last.ckpt` copy of the latest checkpoint
    save_top_k: 5  # keep the 5 latest checkpoints
    monitor: step  # use step to monitor checkpoints
    mode: max  # save the checkpoints with the highest global_step
    every_n_epochs: null  # don't save checkpoints at epoch end
    every_n_train_steps: 5000  # save checkpoints every 5000 steps
    auto_insert_metric_name: false

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # the maximum depth of layer nesting that the summary will include

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: False
  #   id: null  # pass correct id to resume experiment!
  #   anonymous: null  # enable anonymous logging
  #   project: "fish-speech"
  #   log_model: False  # upload lightning ckpts
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
train: true
test: false