import os
import sys
import time
import glob
import numpy as np
import utils
from PIL import Image
import logging
import argparse
import torch
import torch.nn as nn
import torch.utils
import torch.backends.cudnn as cudnn
from model import *
from multi_read_data import DataLoader
parser = argparse.ArgumentParser("ZERO-IG")
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--epochs', type=int, default=2001, help='epochs')
parser.add_argument('--lr', type=float, default=0.0003, help='learning rate')
parser.add_argument('--save', type=str, default='./EXP/', help='directory to save experiment outputs')
parser.add_argument('--model_pretrain', type=str, default='', help='path to pretrained model weights')
args = parser.parse_args()
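# Example invocation (a sketch: the script filename "train.py" is an assumption, and
# the ./data/1 directory used below must point at your own low-light training images):
#   python train.py --gpu 0 --batch_size 1 --epochs 2001 --lr 0.0003 --save ./EXP/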
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
args.save = args.save + '/' + 'Train-{}'.format(time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
model_path = args.save + '/model_epochs/'
os.makedirs(model_path, exist_ok=True)
image_path = args.save + '/image_epochs/'
os.makedirs(image_path, exist_ok=True)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
logging.info("train file name = %s", os.path.split(__file__))
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        print("WARNING: It looks like you have a CUDA device, but aren't "
              "using CUDA.\nRun with --cuda for optimal training speed.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

def save_images(tensor):
    # Convert the first element of a CHW float batch in [0, 1] to an HWC uint8 array.
    image_numpy = tensor[0].cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    im = np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8')
    return im

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNGs for reproducibility.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %s' % args.gpu)
    logging.info("args = %s", args)

    model = Network()
    utils.save(model, os.path.join(args.save, 'initial_weights.pt'))
    # Re-initialise the enhancement sub-network weights.
    model.enhance.in_conv.apply(model.enhance_weights_init)
    model.enhance.conv.apply(model.enhance_weights_init)
    model.enhance.out_conv.apply(model.enhance_weights_init)
    model = model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=3e-4)
    MB = utils.count_parameters_in_MB(model)
    logging.info("model size = %f", MB)
    print(MB)

    train_low_data_names = './data/1'
    TrainDataset = DataLoader(img_dir=train_low_data_names, task='train')
    test_low_data_names = './data/1'
    TestDataset = DataLoader(img_dir=test_low_data_names, task='test')
    # A CUDA generator is required because the default tensor type was set to CUDA above.
    train_queue = torch.utils.data.DataLoader(
        TrainDataset, batch_size=args.batch_size,
        pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))
    test_queue = torch.utils.data.DataLoader(
        TestDataset, batch_size=1,
        pin_memory=True, num_workers=0, shuffle=False, generator=torch.Generator(device='cuda'))

    total_step = 0
    model.train()
    for epoch in range(args.epochs):
        losses = []
        for idx, (input, img_name) in enumerate(train_queue):
            total_step += 1
            input = input.cuda()
            optimizer.zero_grad()
            # Work around Adam's 'capturable' assertion when the optimizer state lives on CUDA.
            optimizer.param_groups[0]['capturable'] = True
            loss = model._loss(input)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            losses.append(loss.item())
            logging.info('train-epoch %03d %03d %f', epoch, idx, loss.item())
        logging.info('train-epoch %03d %f', epoch, np.average(losses))
        utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))

        # Every 50 epochs, run inference on the test set and save intermediate results.
        if epoch % 50 == 0 and total_step != 0:
            model.eval()
            with torch.no_grad():
                for idx, (input, img_name) in enumerate(test_queue):
                    input = input.cuda()
                    image_name = img_name[0].split('/')[-1].split('.')[0]
                    L_pred1, L_pred2, L2, s2, s21, s22, H2, H11, H12, H13, s13, H14, s14, H3, s3, \
                        H3_pred, H4_pred, L_pred1_L_pred2_diff, H13_H14_diff, H2_blur, H3_blur = model(input)
                    input_name = '%s' % image_name
                    H3 = save_images(H3)  # denoised output
                    H2 = save_images(H2)  # enhanced output
                    os.makedirs(args.save + '/result/denoise/', exist_ok=True)
                    os.makedirs(args.save + '/result/enhance/', exist_ok=True)
                    Image.fromarray(H3).save(args.save + '/result/denoise/' + input_name + '_denoise_' + str(epoch) + '.png', 'PNG')
                    Image.fromarray(H2).save(args.save + '/result/enhance/' + input_name + '_enhance_' + str(epoch) + '.png', 'PNG')
            model.train()  # switch back to training mode for the next epoch


if __name__ == '__main__':
    main()