# ZeroIG / test.py
# syedaoon's picture
# Upload 7 files
# eb5b895 verified
import os
import sys
import numpy as np
import torch
import argparse
import logging
import torch.utils
from PIL import Image
from torch.autograd import Variable
from model import Finetunemodel
from multi_read_data import DataLoader
from thop import profile
# Make the parent directory importable.
# NOTE(review): this runs after the `model` / `multi_read_data` imports at the
# top of the file, so it cannot affect how those two modules are resolved.
root_dir = os.path.abspath('../')
sys.path.append(root_dir)

# Command-line options of the test script. Each help text describes what the
# option is used for further down in this file.
parser = argparse.ArgumentParser("ZERO-IG")
parser.add_argument('--data_path_test_low', type=str, default='./data',
                    help='location of the test images (passed to the dataset as img_dir)')
parser.add_argument('--save', type=str,
                    default='./results/',
                    help='directory where log.txt and the result images are written')
parser.add_argument('--model_test', type=str,
                    default='./model',
                    help='location of the model to test')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')
args = parser.parse_args()

# The output directory must exist before the log file handler is attached.
save_path = args.save
os.makedirs(save_path, exist_ok=True)

# Log to stdout and, with the same message format, to <save>/log.txt.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
# `mertic` (sic) is kept as the module-level name of the file handler.
mertic = logging.FileHandler(os.path.join(args.save, 'log.txt'))
mertic.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(mertic)
logging.info("test file name = %s", os.path.split(__file__))

# One image per batch, in dataset order, loaded in the main process.
TestDataset = DataLoader(img_dir=args.data_path_test_low, task='test')
test_queue = torch.utils.data.DataLoader(TestDataset, batch_size=1, pin_memory=True, num_workers=0, shuffle=False)
def save_images(tensor):
    """Convert the first item of a batched tensor into a uint8 image array.

    Despite its name, this function writes nothing to disk; it only prepares
    an array (the caller in this file passes it to ``Image.fromarray``).

    Args:
        tensor: Batched tensor. Only ``tensor[0]`` is used, and it must be
            3-dimensional because its axes are reordered to ``(1, 2, 0)``
            (presumably channels-first to channels-last -- confirm against
            the model output). Tensors that require grad are accepted.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: the values scaled by 255,
        clipped to ``[0, 255]`` and truncated (not rounded) to integers.
    """
    # detach() lets this work on tensors that are part of an autograd graph;
    # without it, .numpy() raises for tensors with requires_grad=True.
    first_item = tensor[0].detach().cpu().float().numpy()
    reordered = np.transpose(first_item, (1, 2, 0))
    return np.clip(reordered * 255.0, 0, 255.0).astype('uint8')
def calculate_model_parameters(model):
    """Return the total element count over everything ``model.parameters()`` yields.

    Each yielded object contributes its ``numel()``; a model that yields no
    parameters gives 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def calculate_model_flops(model, input_tensor):
    """Profile ``model`` on ``input_tensor`` and return the count in units of 1e9.

    ``profile`` is called with ``input_tensor`` as the model's only input.
    The first of the two values it returns is divided by 1e9 and returned;
    the second is discarded.

    NOTE(review): thop's README labels that first value as MACs rather than
    FLOPs -- confirm which unit is intended before reporting the number.
    """
    single_input = (input_tensor,)
    raw_count, _ignored = profile(model, inputs=single_input)
    # Scale the raw count down to "giga" units.
    return raw_count / 1e9
def main():
    """Run the model over every test image and save both outputs as PNGs.

    Exits with status 1 when CUDA is unavailable. Otherwise the model is
    built from ``args.model_test``, moved to the GPU and put in eval mode,
    and its parameter count is printed. For each item of ``test_queue`` the
    two tensors returned by the model are converted with ``save_images`` and
    written to ``<args.save>/result`` as ``<name>_enhance.png`` and
    ``<name>_denoise.png``, where ``<name>`` is the input file name without
    its directory and final extension.
    """
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    # NOTE(review): args.gpu and args.seed are parsed but never applied in
    # this file -- confirm whether that is intentional.
    model = Finetunemodel(args.model_test)
    model = model.cuda()
    model.eval()

    # Calculate model size
    total_params = calculate_model_parameters(model)
    print("Total number of parameters: ", total_params)

    for p in model.parameters():
        p.requires_grad = False

    # The output directory is loop-invariant, so create it once up front.
    result_dir = os.path.join(args.save, 'result')
    os.makedirs(result_dir, exist_ok=True)

    with torch.no_grad():
        for low_image, img_name in test_queue:
            # Variable(..., volatile=True) is a deprecated no-op that only
            # emits a warning; torch.no_grad() above already disables autograd.
            low_image = low_image.cuda()

            # basename/splitext keep dots inside the file name ("a.b.png" ->
            # "a.b") and handle the platform's path separators.
            image_stem = os.path.splitext(os.path.basename(img_name[0]))[0]

            enhance, output = model(low_image)
            enhance = save_images(enhance)
            output = save_images(output)

            Image.fromarray(output).save(
                os.path.join(result_dir, image_stem + '_denoise' + '.png'), 'PNG')
            Image.fromarray(enhance).save(
                os.path.join(result_dir, image_stem + '_enhance' + '.png'), 'PNG')

    # Kept from the original: leave gradient computation globally enabled
    # once inference is finished.
    torch.set_grad_enabled(True)


# Script entry point: run the inference loop when executed directly.
if __name__ == '__main__':
    main()