Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from omegaconf import OmegaConf, DictConfig | |
from dataclasses import dataclass, field | |
from typing import Any, Dict, List, Optional, Union | |
from datetime import datetime | |
class ExperimentConfig: | |
name: str = "default" | |
tag: str = "" | |
use_timestamp: bool = False | |
timestamp: Optional[str] = None | |
exp_root_dir: str = "outputs" | |
### these shouldn't be set manually | |
exp_dir: str = "outputs/default" | |
trial_name: str = "exp" | |
trial_dir: str = "outputs/default/exp" | |
### | |
resume: Optional[str] = None | |
ckpt_path: Optional[str] = None | |
data: dict = field(default_factory=dict) | |
model_pl: dict = field(default_factory=dict) | |
trainer: dict = field(default_factory=dict) | |
checkpoint: dict = field(default_factory=dict) | |
checkpoint_epoch: Optional[dict] = None | |
wandb: dict = field(default_factory=dict) | |
def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: | |
if from_string: | |
yaml_confs = [OmegaConf.create(s) for s in yamls] | |
else: | |
yaml_confs = [OmegaConf.load(f) for f in yamls] | |
cli_conf = OmegaConf.from_cli(cli_args) | |
cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) | |
OmegaConf.resolve(cfg) | |
assert isinstance(cfg, DictConfig) | |
scfg = parse_structured(ExperimentConfig, cfg) | |
return scfg | |
def config_to_primitive(config, resolve: bool = True) -> Any: | |
return OmegaConf.to_container(config, resolve=resolve) | |
def dump_config(path: str, config) -> None: | |
with open(path, "w") as fp: | |
OmegaConf.save(config=config, f=fp) | |
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: | |
scfg = OmegaConf.structured(fields(**cfg)) | |
return scfg |