|
import os
|
|
import sys
|
|
import time
|
|
import glob
|
|
import numpy as np
|
|
import utils
|
|
from PIL import Image
|
|
import logging
|
|
import argparse
|
|
import torch.utils
|
|
import torch.backends.cudnn as cudnn
|
|
from torch.autograd import Variable
|
|
from model import *
|
|
from multi_read_data import DataLoader
|
|
|
|
|
|
parser = argparse.ArgumentParser("ZERO-IG")
|
|
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
|
|
parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
|
|
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
|
|
parser.add_argument('--seed', type=int, default=2, help='random seed')
|
|
parser.add_argument('--epochs', type=int, default=2001, help='epochs')
|
|
parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
|
|
parser.add_argument('--save', type=str, default='./EXP/', help='location of the data corpus')
|
|
parser.add_argument('--model_pretrain', type=str,default='',help='location of the data corpus')
|
|
|
|
args = parser.parse_args()
|
|
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
|
|
|
|
args.save = args.save + '/' + 'Train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
|
|
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
|
|
model_path = args.save + '/model_epochs/'
|
|
os.makedirs(model_path, exist_ok=True)
|
|
image_path = args.save + '/image_epochs/'
|
|
os.makedirs(image_path, exist_ok=True)
|
|
|
|
log_format = '%(asctime)s %(message)s'
|
|
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
|
|
format=log_format, datefmt='%m/%d %I:%M:%S %p')
|
|
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
|
|
fh.setFormatter(logging.Formatter(log_format))
|
|
logging.getLogger().addHandler(fh)
|
|
|
|
logging.info("train file name = %s", os.path.split(__file__))
|
|
|
|
if torch.cuda.is_available():
|
|
if args.cuda:
|
|
torch.set_default_tensor_type('torch.cuda.FloatTensor')
|
|
if not args.cuda:
|
|
print("WARNING: It looks like you have a CUDA device, but aren't " +
|
|
"using CUDA.\nRun with --cuda for optimal training speed.")
|
|
torch.set_default_tensor_type('torch.FloatTensor')
|
|
else:
|
|
torch.set_default_tensor_type('torch.FloatTensor')
|
|
|
|
|
|
def save_images(tensor):
|
|
image_numpy = tensor[0].cpu().float().numpy()
|
|
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
|
|
im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
|
|
return im
|
|
|
|
|
|
def main():
|
|
if not torch.cuda.is_available():
|
|
logging.info('no gpu device available')
|
|
sys.exit(1)
|
|
|
|
np.random.seed(args.seed)
|
|
cudnn.benchmark = True
|
|
torch.manual_seed(args.seed)
|
|
cudnn.enabled = True
|
|
torch.cuda.manual_seed(args.seed)
|
|
logging.info('gpu device = %s' % args.gpu)
|
|
logging.info("args = %s", args)
|
|
|
|
|
|
|
|
model =Network()
|
|
utils.save(model, os.path.join(args.save, 'initial_weights.pt'))
|
|
model.enhance.in_conv.apply(model.enhance_weights_init)
|
|
model.enhance.conv.apply(model.enhance_weights_init)
|
|
model.enhance.out_conv.apply(model.enhance_weights_init)
|
|
model = model.cuda()
|
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)
|
|
MB = utils.count_parameters_in_MB(model)
|
|
logging.info("model size = %f", MB)
|
|
print(MB)
|
|
train_low_data_names = './data/1'
|
|
TrainDataset = DataLoader(img_dir=train_low_data_names, task='train')
|
|
|
|
test_low_data_names = './data/1'
|
|
TestDataset = DataLoader(img_dir=test_low_data_names, task='test')
|
|
|
|
train_queue = torch.utils.data.DataLoader(
|
|
TrainDataset, batch_size=args.batch_size,
|
|
pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
|
|
test_queue = torch.utils.data.DataLoader(
|
|
TestDataset, batch_size=1,
|
|
pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
|
|
|
|
total_step = 0
|
|
model.train()
|
|
for epoch in range(args.epochs):
|
|
losses = []
|
|
for idx, (input, img_name) in enumerate(train_queue):
|
|
total_step += 1
|
|
input = Variable(input, requires_grad=False).cuda()
|
|
optimizer.zero_grad()
|
|
optimizer.param_groups[0]['capturable'] = True
|
|
loss = model._loss(input)
|
|
loss.backward()
|
|
nn.utils.clip_grad_norm_(model.parameters(), 5)
|
|
optimizer.step()
|
|
losses.append(loss.item())
|
|
logging.info('train-epoch %03d %03d %f', epoch, idx, loss)
|
|
logging.info('train-epoch %03d %f', epoch, np.average(losses))
|
|
utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))
|
|
|
|
if epoch % 50 == 0 and total_step != 0:
|
|
model.eval()
|
|
with torch.no_grad():
|
|
for idx, (input, img_name) in enumerate(test_queue):
|
|
input = Variable(input, volatile=True).cuda()
|
|
image_name = img_name[0].split('/')[-1].split('.')[0]
|
|
L_pred1,L_pred2,L2,s2,s21,s22,H2,H11,H12,H13,s13,H14,s14,H3,s3,H3_pred,H4_pred,L_pred1_L_pred2_diff,H13_H14_diff,H2_blur,H3_blur= model(input)
|
|
input_name = '%s' % (image_name)
|
|
H3 = save_images(H3)
|
|
H2= save_images(H2)
|
|
os.makedirs(args.save + '/result/denoise/', exist_ok=True)
|
|
os.makedirs(args.save + '/result/enhance/', exist_ok=True)
|
|
Image.fromarray(H3).save(args.save + '/result/denoise/' + input_name+'_denoise_'+str(epoch)+'.png', 'PNG')
|
|
Image.fromarray(H2).save(args.save + '/result/enhance/' +input_name+'_enhance_'+str(epoch)+'.png', 'PNG')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main() |