# newTryOn/models/STAR/trainer.py
# Last commit 6d314be by amanSethSmava: "new commit"
import argparse
import os
import shutil
import sys
import time
import traceback

import torch
import torch.nn as nn

from lib import utility
from lib.utils import AverageMeter, convert_secs2time
# Ask MKL to use the GNU OpenMP threading layer (presumably to avoid the
# MKL / libgomp threading-layer clash seen when mixing MKL with GCC-built
# OpenMP runtimes -- confirm). NOTE(review): this runs after `torch` has been
# imported, so whether it still affects MKL in *this* process depends on when
# MKL initialises; processes started later inherit it via the environment.
os.environ.update({"MKL_THREADING_LAYER": "GNU"})
# To restrict the visible GPUs, CUDA_VISIBLE_DEVICES (e.g. "0,1,2,3") can be
# set here as well.
def train(args):
    """Launch training on the devices listed in ``args.device_ids``.

    With more than one entry, one worker process per entry is started with
    ``torch.multiprocessing.spawn``; spawn calls each worker with its process
    index (0..n-1) as the first argument, which ``train_worker`` receives as
    ``world_rank``. With exactly one entry, ``train_worker`` runs in this
    process and receives ``device_ids[0]`` itself as ``world_rank``.

    Args:
        args: Parsed CLI arguments; ``args.device_ids`` must be a sized
            sequence of device ids. ``args`` is forwarded to every worker.

    Raises:
        ValueError: If ``args.device_ids`` is empty.
    """
    device_ids = args.device_ids
    nprocs = len(device_ids)
    if nprocs > 1:
        # NOTE(review): only the *number* of device ids matters on this path:
        # spawn passes process indices, and train_worker derives the device
        # index from world_rank, so e.g. [2, 3] would use indices 0 and 1.
        # Confirm this is intended.
        torch.multiprocessing.spawn(
            train_worker, args=(nprocs, 1, args), nprocs=nprocs, join=True)
    elif nprocs == 1:
        # NOTE(review): the device id doubles as world_rank here, so when it is
        # not 0 the rank-0-only work in train_worker (instance/logger init,
        # validation, checkpointing) is skipped. Confirm this is intended.
        train_worker(device_ids[0], nprocs, 1, args)
    else:
        # Explicit raise instead of `assert False`, which `python -O` strips
        # (making an empty device list a silent no-op).
        raise ValueError(
            "args.device_ids must contain at least one device id, got %r" % (device_ids,))
def train_worker(world_rank, world_size, nodes_size, args):
    """Run the complete training loop in one (possibly spawned) worker.

    Args:
        world_rank: Rank of this worker. With ``nodes_size == 1`` it is also
            used directly as ``config.device_id``. Rank 0 is the only rank
            that initialises instances, validates and saves checkpoints.
        world_size: Total number of workers. ``> 1`` enables an NCCL process
            group, SyncBatchNorm conversion and DistributedDataParallel.
        nodes_size: Number of nodes. ``1`` rendezvouses on a fixed localhost
            TCP port; otherwise ``env://`` is used and the device index is
            ``world_rank % torch.cuda.device_count()``.
        args: Parsed CLI arguments; passed to ``utility.get_config`` and
            ``args.pretrained_weight`` (path or None) is read here.

    Exceptions raised inside an epoch (``Exception`` subclasses) are printed
    and logged and training continues with the next epoch; ``KeyboardInterrupt``
    and ``SystemExit`` propagate so a run can actually be stopped.
    """
    # initialize config.
    config = utility.get_config(args)
    config.device_id = world_rank if nodes_size == 1 else world_rank % torch.cuda.device_count()
    # set environment
    utility.set_environment(config)

    # Initialize instances (writer, logger, wandb) on rank 0 only; other ranks
    # presumably keep config.logger as None, hence the None checks throughout.
    if world_rank == 0:
        config.init_instance()
    if config.logger is not None:
        config.logger.info("\n" + "\n".join(["%s: %s" % item for item in config.__dict__.items()]))
        config.logger.info("Loaded configure file %s: %s" % (config.type, config.id))

    # worker communication
    if world_size > 1:
        torch.distributed.init_process_group(
            backend="nccl", init_method="tcp://localhost:23456" if nodes_size == 1 else "env://",
            rank=world_rank, world_size=world_size)
        torch.cuda.set_device(config.device)

    # model
    net = utility.get_net(config)
    if world_size > 1:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.float().to(config.device)
    net.train(True)
    # The EMA copy lives only on rank 0, the only rank that validates and saves.
    if config.ema and world_rank == 0:
        net_ema = utility.get_net(config)
        if world_size > 1:
            net_ema = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net_ema)
        net_ema = net_ema.float().to(config.device)
        net_ema.eval()
        # presumably decay 0 == copy net's weights into net_ema -- confirm.
        utility.accumulate_net(net_ema, net, 0)
    else:
        net_ema = None

    # multi-GPU training
    if world_size > 1:
        net_module = nn.parallel.DistributedDataParallel(
            net, device_ids=[config.device_id],
            output_device=config.device_id, find_unused_parameters=True)
    else:
        net_module = net

    criterions = utility.get_criterions(config)
    optimizer = utility.get_optimizer(config, net_module)
    scheduler = utility.get_scheduler(config, optimizer)

    # load pretrain model
    if args.pretrained_weight is not None:
        start_epoch = _load_pretrained(config, args.pretrained_weight, net, net_ema, optimizer, scheduler)
    else:
        start_epoch = 0
    if config.logger is not None:
        config.logger.info("Loaded network")

    # data - train, val (validation data is only needed on rank 0)
    train_loader = utility.get_dataloader(config, "train", world_rank, world_size)
    val_loader = utility.get_dataloader(config, "val") if world_rank == 0 else None
    if config.logger is not None:
        config.logger.info("Loaded data")

    # forward & backward
    if config.logger is not None:
        config.logger.info("Optimizer type %s. Start training..." % (config.optimizer))
    if world_rank == 0:
        # exist_ok avoids the exists()/makedirs() check-then-act race.
        os.makedirs(config.model_dir, exist_ok=True)

    # training
    best_metric = None
    epoch_time, eval_time = AverageMeter(), AverageMeter()
    for epoch in range(config.max_epoch + 1):
        if epoch < start_epoch:
            # Already covered by the restored checkpoint, including its
            # scheduler state -- so the scheduler is not stepped again here.
            continue
        try:
            epoch_start_time = time.time()
            # start_epoch is 0 on a fresh run or the epoch stored in a restored
            # checkpoint; no training step is run for it.
            if epoch != start_epoch:
                utility.forward_backward(config, train_loader, net_module, net, net_ema, criterions, optimizer,
                                         epoch)
            if world_size > 1:
                torch.distributed.barrier()

            # validating (rank 0 only; epoch 0 is never validated)
            if epoch % config.val_epoch == 0 and epoch != 0 and world_rank == 0:
                eval_start_time = time.time()
                best_metric = _validate(config, epoch, val_loader, net, net_ema, optimizer, scheduler,
                                        best_metric)
                eval_time.update(time.time() - eval_start_time)

            # saving model
            if epoch == config.max_epoch and world_rank == 0:
                utility.save_model(
                    config, epoch, net, net_ema, optimizer, scheduler,
                    os.path.join(config.model_dir, "last_model.pkl"))
            if world_size > 1:
                torch.distributed.barrier()

            # adjusting learning rate
            if epoch > 0:
                scheduler.step()
            epoch_time.update(time.time() - epoch_start_time)
            # The loop index starts at 0, so (max_epoch - epoch) epochs remain.
            last_time = convert_secs2time(epoch_time.avg * (config.max_epoch - epoch), True)
            if config.logger is not None:
                config.logger.info(
                    "Train/Epoch: %d/%d, Learning rate decays to %s, " % (
                        epoch, config.max_epoch, str(scheduler.get_last_lr()))
                    + last_time + 'eval_time: {:4.2f}, '.format(eval_time.avg) + '\n\n')
        except Exception:
            # Best-effort: report and continue with the next epoch. Narrowed from a
            # bare `except:` so Ctrl-C / SystemExit stop training, and guarded since
            # ranks without a logger would otherwise crash inside this handler.
            traceback.print_exc()
            if config.logger is not None:
                config.logger.error("Exception happened in training steps")

    if config.logger is not None:
        config.logger.info("Training finished")

    # Tag the run directory with the best metric (the logger check limits this
    # to the rank that owns instances).
    if config.logger is not None and best_metric is not None:
        try:
            _rename_work_dir(config, best_metric)
        except Exception:
            traceback.print_exc()

    if world_size > 1:
        torch.distributed.destroy_process_group()


def _load_pretrained(config, weight_path, net, net_ema, optimizer, scheduler):
    """Restore a training checkpoint and return the epoch to resume from.

    ``weight_path`` is used as given if it exists, otherwise it is resolved
    relative to ``config.work_dir``. The checkpoint is read as a mapping with
    keys ``net``, ``net_ema``, ``epoch``, ``optimizer`` and ``scheduler``.
    On any failure a warning is logged and 0 is returned; weights already
    copied into ``net``/``net_ema`` before the failure are kept.
    """
    if not os.path.exists(weight_path):
        weight_path = os.path.join(config.work_dir, weight_path)
    try:
        # map_location loads tensors straight onto this worker's device instead
        # of the device they were saved from (and allows CUDA->CPU resume).
        checkpoint = torch.load(weight_path, map_location=config.device)
        net.load_state_dict(checkpoint["net"], strict=False)
        if net_ema is not None:
            net_ema.load_state_dict(checkpoint["net_ema"], strict=False)
        start_epoch = checkpoint["epoch"]
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    except Exception:
        if config.logger is not None:
            config.logger.warn("Failed to load pretrain model %s." % weight_path)
        return 0
    # Reported only once everything (incl. optimizer/scheduler) has loaded.
    if config.logger is not None:
        config.logger.warn("Successed to load pretrain model %s." % weight_path)
    return start_epoch


def _validate(config, epoch, val_loader, net, net_ema, optimizer, scheduler, best_metric):
    """Evaluate ``net`` and, if present, ``net_ema``; checkpoint any new best.

    The selection value is ``metrics[config.key_metric_index][0]`` (the NME
    slot in the log format); lower is better. Whenever a model beats
    ``best_metric`` it is passed to ``utility.save_model`` in place of ``net``
    (with the current ``net_ema`` alongside) and written to
    ``<model_dir>/best_model.pkl``. Returns the possibly updated best metric.
    """
    for net_name, epoch_net in (("net", net), ("net_ema", net_ema)):
        if epoch_net is None:
            continue
        _, metrics = utility.forward(config, val_loader, epoch_net)
        for k, metric in enumerate(metrics):
            if config.logger is not None and len(metric) != 0:
                config.logger.info(
                    "Val_{}/Metric{:3d} in this epoch: [NME {:.6f}, FR {:.6f}, AUC {:.6f}]".format(
                        net_name, k, metric[0], metric[1], metric[2]))
        # update best model.
        cur_metric = metrics[config.key_metric_index][0]
        if best_metric is None or best_metric > cur_metric:
            best_metric = cur_metric
            utility.save_model(
                config, epoch, epoch_net, net_ema, optimizer, scheduler,
                os.path.join(config.model_dir, "best_model.pkl"))
    if best_metric is not None and config.logger is not None:
        config.logger.info(
            "Val/Best_Metric%03d in this epoch: %.6f" % (config.key_metric_index, best_metric))
    return best_metric


def _rename_work_dir(config, best_metric):
    """Move ``config.work_dir`` to ``<ckpt_dir>/<data_definition>/<folder>-fin-<best_metric>``."""
    new_folder_name = config.folder + '-fin-{:.4f}'.format(best_metric)
    new_work_dir = os.path.join(config.ckpt_dir, config.data_definition, new_folder_name)
    # shutil.move instead of `os.system('mv ...')`: no shell string, so paths with
    # spaces/metacharacters work, and failures raise instead of being ignored.
    shutil.move(config.work_dir, new_work_dir)