|
import numpy as np
|
|
import torch
|
|
import torch.utils.data
|
|
from PIL import Image
|
|
import torchvision.transforms as transforms
|
|
import os
|
|
|
|
class DataLoader(torch.utils.data.Dataset):
    """Dataset over every file found (recursively) under an image directory.

    Despite its name this is a ``torch.utils.data.Dataset``, not a
    ``torch.utils.data.DataLoader``; wrap it in a real DataLoader to batch.
    Each item is ``(image_tensor, file_name)``.

    Args:
        img_dir: Root directory walked with ``os.walk``; every file found is
            treated as an image.
        task: Stored as ``self.task``; not used by this class.
    """

    def __init__(self, img_dir, task):
        self.low_img_dir = img_dir
        self.task = task
        # Initialized empty and never populated in this class; kept so any
        # external code reading the attribute keeps working.
        self.train_target_data_names = []

        # NOTE(review): no extension filter is applied, so stray non-image
        # files (e.g. .DS_Store, Thumbs.db) under img_dir will be listed and
        # then fail when loaded.
        # os.walk order is filesystem-dependent; sorting gives a stable index.
        self.train_low_data_names = sorted(
            os.path.join(root, name)
            for root, _, names in os.walk(self.low_img_dir)
            for name in names
        )
        self.count = len(self.train_low_data_names)

        self.transform = transforms.Compose([transforms.ToTensor()])

    def load_images_transform(self, file):
        """Load ``file`` as RGB and return it as a channels-last numpy array.

        The image is passed through ``self.transform`` (``ToTensor``, which
        yields a channels-first tensor) and then transposed with axes
        ``(1, 2, 0)`` to channels-last.

        Args:
            file: Path of the image to load.

        Returns:
            numpy array in (H, W, C) layout.
        """
        # Context manager closes the underlying file handle; convert() loads
        # the pixel data, so ``rgb`` stays valid after the file is closed.
        with Image.open(file) as im:
            rgb = im.convert('RGB')
        img_norm = self.transform(rgb).numpy()
        return np.transpose(img_norm, (1, 2, 0))

    def __getitem__(self, index):
        """Return ``(image, file_name)`` for the ``index``-th sorted file.

        ``image`` is a float32 tensor in channels-first layout; ``file_name``
        is the base name of the file with its directory stripped.
        """
        path = self.train_low_data_names[index]
        low = np.asarray(self.load_images_transform(path), dtype=np.float32)
        # Back from channels-last to channels-first for torch.
        low = np.transpose(low, (2, 0, 1))
        return torch.from_numpy(low), self._image_name(path)

    @staticmethod
    def _image_name(path):
        """Return the file-name component of ``path``."""
        # os.path.basename honours the platform's separators; splitting on
        # '\\' alone returns the entire path on POSIX systems.
        return os.path.basename(path)

    def __len__(self):
        """Return the number of files discovered at construction time."""
        return self.count
|
|
|