import os

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms as transforms
from PIL import Image


class DataLoader(torch.utils.data.Dataset):
    """Map-style dataset over every file found (recursively) under a directory.

    Despite its name this is a ``torch.utils.data.Dataset``, not a
    ``torch.utils.data.DataLoader``; wrap it in the latter for batching.

    Each item is a ``(tensor, file_name)`` pair:
      * ``tensor`` -- ``float32`` tensor of shape ``(3, H, W)``; the image is
        forced to RGB and passed through ``transforms.ToTensor()``.
      * ``file_name`` -- base name of the source file (no directory part).
    """

    def __init__(self, img_dir, task):
        """Collect and sort all file paths under ``img_dir``.

        Args:
            img_dir: root directory; every file beneath it (any depth) is
                treated as an image -- no extension filtering is applied.
            task: stored as ``self.task``; not used by this class.
        """
        self.low_img_dir = img_dir
        self.task = task
        self.train_low_data_names = []
        # Kept for attribute compatibility; never populated by this class.
        self.train_target_data_names = []

        for root, _dirs, names in os.walk(self.low_img_dir):
            for name in names:
                self.train_low_data_names.append(os.path.join(root, name))
        # Sort so the index -> file mapping is deterministic across runs.
        self.train_low_data_names.sort()
        self.count = len(self.train_low_data_names)

        self.transform = transforms.Compose([transforms.ToTensor()])

    def load_images_transform(self, file):
        """Load ``file`` as RGB and return it as an ``(H, W, 3)`` numpy array.

        Args:
            file: path to an image readable by PIL.

        Returns:
            The ``ToTensor`` output converted to numpy and transposed from
            channel-first ``(3, H, W)`` to channel-last ``(H, W, 3)``.
        """
        # The context manager closes the underlying file handle
        # deterministically instead of leaving it to the garbage collector;
        # convert() returns a fully loaded, independent image.
        with Image.open(file) as im:
            rgb = im.convert('RGB')
        img_norm = self.transform(rgb).numpy()  # (3, H, W)
        return np.transpose(img_norm, (1, 2, 0))

    def __getitem__(self, index):
        """Return ``(image_tensor, file_name)`` for the ``index``-th file.

        Raises:
            IndexError: if ``index`` is out of range for the file list.
        """
        path = self.train_low_data_names[index]
        low = np.asarray(self.load_images_transform(path), dtype=np.float32)
        low = np.transpose(low, (2, 0, 1))  # back to channel-first (3, H, W)
        # os.path.basename honours the platform's separator(s); the previous
        # split('\\') returned the whole path on POSIX systems.
        img_name = os.path.basename(path)
        return torch.from_numpy(low), img_name

    def __len__(self):
        """Number of files discovered under ``img_dir``."""
        return self.count