# ZeroIG / train.py
# Web-page residue from the file listing: uploader "syedaoon",
# commit message "Upload 7 files", revision eb5b895 (verified).
import os
import sys
import time
import glob
import numpy as np
import utils
from PIL import Image
import logging
import argparse
import torch.utils
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
from model import *
from multi_read_data import DataLoader
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--cuda False``
    would silently keep CUDA enabled.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # The empty string is kept as "false" because it was the only way to
    # disable the flag while the option still used ``type=bool``.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Command-line interface of the training script.
parser = argparse.ArgumentParser("ZERO-IG")
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--cuda', default=True, type=_str2bool, help='Use CUDA to train model')
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--epochs', type=int, default=2001, help='epochs')
parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
# The value is used as the parent directory of the per-run "Train-<timestamp>"
# output folder, so the help text describes that instead of a data corpus.
parser.add_argument('--save', type=str, default='./EXP/', help='root directory for experiment outputs')
# NOTE(review): the help text below looks copy-pasted and this option is not
# read anywhere in the visible part of the script - confirm its purpose.
parser.add_argument('--model_pretrain', type=str, default='', help='location of the data corpus')
# Parse the command line once at import time; the rest of the module reads
# the resulting namespace through the global ``args``.
args = parser.parse_args()

# Restrict the visible GPUs before any CUDA call is made further down.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Every run gets its own time-stamped folder below the requested save root.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
args.save = '{}/Train-{}'.format(args.save, run_stamp)
# create_exp_dir is a project helper from ``utils`` (not shown here); it is
# handed the run folder and the list of *.py files in the working directory.
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# Sub-folders for per-epoch checkpoints and images.
model_path = args.save + '/model_epochs/'
image_path = args.save + '/image_epochs/'
for out_dir in (model_path, image_path):
    os.makedirs(out_dir, exist_ok=True)
# Logging goes to stdout and, with the same message layout, to log.txt
# inside the run folder.
log_format = '%(asctime)s %(message)s'

# basicConfig has to run before the file handler is attached: it only
# installs the stdout handler while the root logger has no handlers yet.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)

# The file handler uses the same format string but the default date format.
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
file_formatter = logging.Formatter(log_format)
fh.setFormatter(file_formatter)
root_logger = logging.getLogger()
root_logger.addHandler(fh)

logging.info("train file name = %s", os.path.split(__file__))
# Choose the default tensor type: CUDA float tensors only when a device is
# present AND the user asked for it, CPU float tensors in every other case.
# NOTE(review): torch.set_default_tensor_type is reported as deprecated in
# recent PyTorch releases - confirm against the installed version.
cuda_available = torch.cuda.is_available()
if cuda_available and args.cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    if cuda_available:
        # A device exists but was explicitly switched off.
        print("WARNING: It looks like you have a CUDA device, but aren't " +
              "using CUDA.\nRun with --cuda for optimal training speed.")
    torch.set_default_tensor_type('torch.FloatTensor')
def save_images(tensor):
    """Convert the first item of a batch tensor into a uint8 image array.

    The first element along dimension 0 is taken, moved to the CPU, cast to
    float and its axes are reordered from (0, 1, 2) to (1, 2, 0). Values are
    scaled by 255, clipped to [0, 255] and truncated to ``uint8``.

    Args:
        tensor: Batch tensor whose items are 3-D (presumably
            (N, C, H, W) with values in [0, 1] - confirm with the model).
            Only ``tensor[0]`` is used.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with the item's axes
        reordered to (1, 2, 0), i.e. channel-last if the input is
        channel-first.
    """
    # detach() lets this work on tensors that still require grad; without it
    # .numpy() raises unless the caller is inside torch.no_grad().
    first_item = tensor[0].detach().cpu().float().numpy()
    channel_last = np.transpose(first_item, (1, 2, 0))
    return np.clip(channel_last * 255.0, 0, 255.0).astype('uint8')
def _train_one_epoch(model, optimizer, train_queue, epoch):
    """Run one optimisation pass over ``train_queue``.

    Args:
        model: Network exposing ``_loss(batch)``; expected to live on the GPU.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_queue: Iterable yielding ``(image_batch, image_names)`` pairs.
        epoch: Current epoch index, used only for logging.

    Returns:
        List with one Python float loss value per processed batch.
    """
    losses = []
    for idx, (low_img, _img_name) in enumerate(train_queue):
        low_img = low_img.cuda()
        optimizer.zero_grad()
        loss = model._loss(low_img)
        loss.backward()
        # torch.nn is used directly instead of relying on ``nn`` leaking in
        # through ``from model import *``.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
        optimizer.step()
        # Read the scalar once and reuse it for bookkeeping and logging.
        loss_value = loss.item()
        losses.append(loss_value)
        logging.info('train-epoch %03d %03d %f', epoch, idx, loss_value)
    return losses


def _save_test_results(model, test_queue, epoch):
    """Run the model on ``test_queue`` and write the H3 / H2 outputs as PNGs.

    H3 is written below ``<save>/result/denoise/`` and H2 below
    ``<save>/result/enhance/``; file names carry the image stem and epoch.
    The model is switched to eval mode for the pass and its previous mode is
    restored afterwards, so training does not continue in eval mode.

    Args:
        model: Network whose forward pass returns the 21 values unpacked below.
        test_queue: Iterable yielding ``(image_batch, image_names)`` pairs.
        epoch: Current epoch index, embedded in the output file names.
    """
    denoise_dir = args.save + '/result/denoise/'
    enhance_dir = args.save + '/result/enhance/'
    os.makedirs(denoise_dir, exist_ok=True)
    os.makedirs(enhance_dir, exist_ok=True)

    was_training = model.training
    model.eval()
    try:
        # no_grad already disables autograd, so the removed ``volatile`` flag
        # and the ``Variable`` wrapper are not needed.
        with torch.no_grad():
            for low_img, img_name in test_queue:
                low_img = low_img.cuda()
                # File stem: text after the last '/' and before the first '.'.
                image_name = img_name[0].split('/')[-1].split('.')[0]
                (L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13,
                 H14, s14, H3, s3, H3_pred, H4_pred, L_pred1_L_pred2_diff,
                 H13_H14_diff, H2_blur, H3_blur) = model(low_img)
                denoised = save_images(H3)
                enhanced = save_images(H2)
                Image.fromarray(denoised).save(
                    denoise_dir + image_name + '_denoise_' + str(epoch) + '.png', 'PNG')
                Image.fromarray(enhanced).save(
                    enhance_dir + image_name + '_enhance_' + str(epoch) + '.png', 'PNG')
    finally:
        # Without this the model would stay in eval mode for every training
        # epoch that follows the first evaluation.
        model.train(was_training)


def main():
    """Train the network and periodically export test-set results.

    Exits with status 1 when no CUDA device is available. Otherwise it seeds
    the RNGs, builds the model and Adam optimizer, trains for ``args.epochs``
    epochs on ``./data/1``, saves the weights after every epoch and, every
    50th epoch, writes the H3 / H2 outputs for the test images.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %s', args.gpu)
    logging.info("args = %s", args)

    model = Network()
    # NOTE(review): this snapshot is written before enhance_weights_init is
    # applied below - confirm that is the intended "initial" state.
    utils.save(model, os.path.join(args.save, 'initial_weights.pt'))
    model.enhance.in_conv.apply(model.enhance_weights_init)
    model.enhance.conv.apply(model.enhance_weights_init)
    model.enhance.out_conv.apply(model.enhance_weights_init)
    model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)
    # Previously re-assigned on every iteration; the value never changes, so
    # setting it once before the first step is equivalent.
    optimizer.param_groups[0]['capturable'] = True

    MB = utils.count_parameters_in_MB(model)
    logging.info("model size = %f", MB)
    print(MB)

    # Training and test sets both read from the same folder.
    train_low_data_names = './data/1'
    TrainDataset = DataLoader(img_dir=train_low_data_names, task='train')
    test_low_data_names = './data/1'
    TestDataset = DataLoader(img_dir=test_low_data_names, task='test')

    train_queue = torch.utils.data.DataLoader(
        TrainDataset, batch_size=args.batch_size,
        pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
    test_queue = torch.utils.data.DataLoader(
        TestDataset, batch_size=1,
        pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))

    total_step = 0
    model.train()
    for epoch in range(args.epochs):
        losses = _train_one_epoch(model, optimizer, train_queue, epoch)
        total_step += len(losses)

        logging.info('train-epoch %03d %f', epoch, np.average(losses))
        utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))

        # Export results every 50th epoch, but only once at least one
        # training step has been taken.
        if epoch % 50 == 0 and total_step != 0:
            _save_test_results(model, test_queue, epoch)


if __name__ == '__main__':
    main()