|
import os |
|
from source.trainer import EDGSTrainer |
|
from source.utils_aux import set_seed |
|
import omegaconf |
|
import wandb |
|
import hydra |
|
from argparse import Namespace |
|
from omegaconf import OmegaConf |
|
|
|
|
|
@hydra.main(config_path="configs", config_name="train", version_base="1.2")
def main(cfg: omegaconf.DictConfig):
    """Run one training job described by the Hydra config ``cfg``.

    Steps, in order: start a Weights & Biases run, resolve the config in
    place and call ``set_seed``, write a ``cfg_args`` file into the model
    output folder, instantiate ``cfg.gs``, then drive an ``EDGSTrainer``
    through ``load_checkpoints``, ``init_with_corr`` and ``train`` before
    closing the W&B run.

    Args:
        cfg: Composed Hydra configuration. This function reads ``wandb``,
            ``seed``, ``gs`` (including ``gs.dataset`` and ``gs.opt``),
            ``device``, ``load``, ``init_wC`` and ``train`` from it.
    """
    wandb_cfg = cfg.wandb
    # The returned run handle is not used; the run is closed later through
    # the module-level wandb.finish().
    wandb.init(
        entity=wandb_cfg.entity,
        project=wandb_cfg.project,
        # Log a plain-container snapshot of the whole config; missing
        # mandatory values raise here instead of being logged.
        config=omegaconf.OmegaConf.to_container(
            cfg, resolve=True, throw_on_missing=True
        ),
        tags=[wandb_cfg.tag],
        name=wandb_cfg.name,
        mode=wandb_cfg.mode,
    )

    # Resolve interpolations in place so every later read sees concrete values.
    omegaconf.OmegaConf.resolve(cfg)
    set_seed(cfg.seed)

    dataset_cfg = cfg.gs.dataset
    output_dir = dataset_cfg.model_path
    print(f"Output folder: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # Persist the run arguments as the string form of an argparse Namespace.
    # Five entries come from cfg.gs.dataset; the rest are fixed values.
    # NOTE(review): presumably read back by downstream rendering/evaluation
    # tooling that expects this exact key set -- confirm before changing keys.
    with open(os.path.join(output_dir, "cfg_args"), 'w') as args_file:
        logged_args = {
            "sh_degree": 3,
            "source_path": dataset_cfg.source_path,
            "model_path": output_dir,
            "images": dataset_cfg.images,
            "depths": "",
            "resolution": -1,
            "_white_background": dataset_cfg.white_background,
            "train_test_exp": False,
            "data_device": dataset_cfg.data_device,
            "eval": False,
            "convert_SHs_python": False,
            "compute_cov3D_python": False,
            "debug": False,
            "antialiasing": False,
        }
        args_file.write(str(Namespace(**logged_args)))

    # Build the object described by the `gs` config node via Hydra.
    gs = hydra.utils.instantiate(cfg.gs)

    trainer = EDGSTrainer(
        GS=gs,
        training_config=cfg.gs.opt,
        device=cfg.device,
    )

    # Each stage receives its own config sub-section; the trainer's timer is
    # started after checkpoint loading and before initialisation.
    trainer.load_checkpoints(cfg.load)
    trainer.timer.start()
    trainer.init_with_corr(cfg.init_wC)
    trainer.train(cfg.train)

    wandb.finish()
    print("\nTraining complete.")
|
|
|
# Script entry point: main() is called without arguments because it is
# wrapped by the @hydra.main decorator.
if __name__ == "__main__":
    main()
|
|
|
|