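"""Distributed training launcher.

Parses the command-line arguments, resolves the flow checkpoint and its YAML
config, then either spawns opt['world_size'] workers on localhost or, in a
multi-node setting, expects the worker processes to be launched by OpenMPI.
"""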
import os
import glob
from importlib import import_module

import yaml  # required by yaml.full_load() below; missing from the original imports
import torch.distributed
import torch.multiprocessing as mp

from utils.dist import *   # get_world_size, get_master_ip, get_local_rank, get_global_rank
from utils.util import find_free_port
from parse import *        # parse
from inputs import args_parser

def main_worker(rank, opt):
    if 'local_rank' not in opt:
        opt['local_rank'] = opt['global_rank'] = rank
    if opt['distributed']:
        torch.cuda.set_device(int(opt['local_rank']))
        torch.distributed.init_process_group(backend='nccl',
                                             init_method=opt['init_method'],
                                             world_size=opt['world_size'],
                                             rank=opt['global_rank'],
                                             group_name='mtorch')
        print('using GPU {}-{} for training'.format(
            int(opt['global_rank']), int(opt['local_rank'])))
    if torch.cuda.is_available():
        opt['device'] = torch.device('cuda:{}'.format(opt['local_rank']))
    else:
        opt['device'] = 'cpu'
    # Dynamically import networks.<name> and hand control to its trainer.
    pkg = import_module('networks.{}'.format(opt['network']))
    trainer = pkg.Network(opt, rank)
    trainer.train()
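
# A minimal sketch of what utils.util.find_free_port (used in main() below)
# plausibly does -- an assumption for illustration, not the repo's actual code:
# bind port 0 so the OS assigns any unused port, read it back, release the socket.
import socket

def _find_free_port_sketch():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # port 0: let the OS choose a free port
        return s.getsockname()[1]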

def main(args_obj):
    opt = parse(args_obj)
    opt['world_size'] = get_world_size()
    free_port = find_free_port()
    master_ip = get_master_ip()
    opt['init_method'] = 'tcp://{}:{}'.format(master_ip, free_port)
    opt['distributed'] = opt['world_size'] > 1
    print(f'World size is: {opt["world_size"]}, and init_method is: {opt["init_method"]}')
    print('Import network module: ', opt['network'])

    # opt['flow_checkPoint'] initially names a directory expected to hold one
    # checkpoint (*.tar) and one config (*.yaml); resolve both files here.
    checkpoint = glob.glob(os.path.join(opt['flow_checkPoint'], '*.tar'))[0]
    config = glob.glob(os.path.join(opt['flow_checkPoint'], '*.yaml'))[0]
    with open(config, 'r') as f:
        configs = yaml.full_load(f)
    opt['flow_config'] = configs
    opt['flow_checkPoint'] = checkpoint
    # Use the args_obj parameter rather than the module-level `args`,
    # which only exists when the script is run directly.
    opt['finetune'] = args_obj['finetune'] == 1
    # Optional state overrides from the command line.
    if opt['gen_state'] != '':
        opt['path']['gen_state'] = opt['gen_state']
    if opt['dis_state'] != '':
        opt['path']['dis_state'] = opt['dis_state']
    if opt['opt_state'] != '':
        opt['path']['opt_state'] = opt['opt_state']

    # Pack the (height, width) pairs into the tuples the network expects.
    opt['input_resolution'] = (opt['res_h'], opt['res_w'])
    opt['kernel_size'] = (opt['kernel_size_h'], opt['kernel_size_w'])
    opt['stride'] = (opt['stride_h'], opt['stride_w'])
    opt['padding'] = (opt['pad_h'], opt['pad_w'])
    print('model is: {}'.format(opt['model']))

    if get_master_ip() == '127.0.0.1':
        # localhost: spawn one worker process per rank
        mp.spawn(main_worker, nprocs=opt['world_size'], args=(opt,))
    else:
        # multi-node: the processes should be launched by OpenMPI,
        # so each one reads its ranks from the environment
        opt['local_rank'] = get_local_rank()
        opt['global_rank'] = get_global_rank()
        main_worker(-1, opt)
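
# Hypothetical sketch of the rank helpers used in the OpenMPI branch above
# (an assumption: utils.dist likely reads the launcher's environment, e.g.
# OpenMPI's OMPI_COMM_WORLD_* variables; the repo's actual code may differ).
def _get_global_rank_sketch():
    return int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))

def _get_local_rank_sketch():
    return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 0))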

if __name__ == '__main__':
    args = args_parser()
    args_obj = vars(args)
    main(args_obj)
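
# Illustrative launch (script name and flag spelling are assumptions; see
# inputs.args_parser for the actual CLI):
#
#     python train.py --finetune 1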