import logging
import os
import pickle
import random

import cv2
import cvbase
import imageio
import numpy as np
import torch
from torch.utils.data.dataset import Dataset

from .util.flow_utils import region_fill as rf
from .util.STTN_mask import create_random_shape_with_random_motion

logger = logging.getLogger('base')
class VideoBasedDataset(Dataset):
    def __init__(self, opt, dataInfo):
        self.opt = opt
        self.sampleMethod = opt['sample']
        self.dataInfo = dataInfo
        self.height, self.width = self.opt['input_resolution']
        self.frame_path = dataInfo['frame_path']
        self.flow_path = dataInfo['flow_path']  # root directory of the precomputed optical flows
        self.train_list = os.listdir(self.frame_path)
        self.name2length = self.dataInfo['name2len']
        with open(self.name2length, 'rb') as f:
            self.name2length = pickle.load(f)  # maps video name -> frame count
        self.sequenceLen = self.opt['num_frames']
        self.flow2rgb = opt['flow2rgb']  # whether to convert flows to the RGB domain
        # flow direction must be one of ['for', 'back', 'bi'],
        # i.e. forward, backward, or bidirectional flows
        self.flow_direction = opt['flow_direction']

    def __len__(self):
        return len(self.train_list)
    def __getitem__(self, idx):
        try:
            item = self.load_item(idx)
        except Exception:
            logger.warning('Loading error: %s', self.train_list[idx])
            item = self.load_item(0)
        return item
    def frameSample(self, frameLen, sequenceLen):
        if self.sampleMethod == 'random':
            indices = [i for i in range(frameLen)]
            sampleIndices = random.sample(indices, sequenceLen)
        elif self.sampleMethod == 'seq':
            # pick a random pivot so that a full window of sequenceLen frames fits
            pivot = random.randint(0, frameLen - sequenceLen)
            sampleIndices = [i for i in range(pivot, pivot + sequenceLen)]
        else:
            raise ValueError('Cannot determine the sample method {}'.format(self.sampleMethod))
        return sampleIndices
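    # Illustration (hypothetical values): with frameLen=100 and sequenceLen=5,
    # 'random' may return something like [3, 17, 42, 66, 91] (no repeats, order
    # not guaranteed), while 'seq' returns a contiguous window such as
    # [37, 38, 39, 40, 41].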
    def load_item(self, idx):
        video = self.train_list[idx]
        frame_dir = os.path.join(self.frame_path, video)
        forward_flow_dir = os.path.join(self.flow_path, video, 'forward_flo')
        backward_flow_dir = os.path.join(self.flow_path, video, 'backward_flo')
        frameLen = self.name2length[video]
        flowLen = frameLen - 1
        assert frameLen > self.sequenceLen, \
            'Frame length {} is not greater than sequence length {}'.format(frameLen, self.sequenceLen)
        sampledIndices = self.frameSample(frameLen, self.sequenceLen)

        # generate random masks for these sampled frames
        candidateMasks = create_random_shape_with_random_motion(frameLen, 0.9, 1.1, 1, 10)

        # read the frames, masks, and requested flows
        frames, masks, forward_flows, backward_flows = [], [], [], []
        for i in range(len(sampledIndices)):
            frame = self.read_frame(os.path.join(frame_dir, '{:05d}.jpg'.format(sampledIndices[i])),
                                    self.height, self.width)
            mask = self.read_mask(candidateMasks[sampledIndices[i]], self.height, self.width)
            frames.append(frame)
            masks.append(mask)
            if self.flow_direction == 'for':
                forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
                forward_flow = self.diffusion_flow(forward_flow, mask)
                forward_flows.append(forward_flow)
            elif self.flow_direction == 'back':
                backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
                backward_flow = self.diffusion_flow(backward_flow, mask)
                backward_flows.append(backward_flow)
            elif self.flow_direction == 'bi':
                forward_flow = self.read_forward_flow(forward_flow_dir, sampledIndices[i], flowLen)
                forward_flow = self.diffusion_flow(forward_flow, mask)
                forward_flows.append(forward_flow)
                backward_flow = self.read_backward_flow(backward_flow_dir, sampledIndices[i])
                backward_flow = self.diffusion_flow(backward_flow, mask)
                backward_flows.append(backward_flow)
            else:
                raise ValueError('Unknown flow direction mode: {}'.format(self.flow_direction))
        inputs = {'frames': frames, 'masks': masks, 'forward_flo': forward_flows, 'backward_flo': backward_flows}
        inputs = self.to_tensor(inputs)
        inputs['frames'] = (inputs['frames'] / 255.) * 2 - 1  # normalize frames to [-1, 1]
        return inputs
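    # Expected output layout (a sketch, with t = num_frames, (h, w) = input_resolution):
    #   frames:       float tensor [t, 3, h, w] in [-1, 1]
    #   masks:        float tensor [t, 1, h, w], 1 inside the hole
    #   forward_flo / backward_flo: float tensor [t, 2, h, w]; only the
    #   direction(s) requested by flow_direction are present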
    def diffusion_flow(self, flow, mask):
        # fill the masked flow region by propagating (diffusing) the known
        # flow values inward from the hole boundary, channel by channel
        flow_filled = np.zeros(flow.shape)
        flow_filled[:, :, 0] = rf.regionfill(flow[:, :, 0] * (1 - mask), mask)
        flow_filled[:, :, 1] = rf.regionfill(flow[:, :, 1] * (1 - mask), mask)
        return flow_filled
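    # For instance, if the flow along the hole boundary is uniformly (2.0, 0.0),
    # regionfill propagates that value across the hole, yielding a smooth
    # completion consistent with the surrounding motion.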
    def read_frame(self, path, height, width):
        frame = imageio.imread(path)
        frame = cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
        return frame
    def read_mask(self, mask, height, width):
        mask = np.array(mask)
        mask = mask / 255.
        raw_mask = (mask > 0.5).astype(np.uint8)
        # nearest-neighbor interpolation keeps the resized mask strictly binary
        raw_mask = cv2.resize(raw_mask, dsize=(width, height), interpolation=cv2.INTER_NEAREST)
        return raw_mask
    def read_forward_flow(self, forward_flow_dir, sampledIndex, flowLen):
        # the last frame has no forward flow, so reuse the final flow file
        if sampledIndex >= flowLen:
            sampledIndex = flowLen - 1
        flow = cvbase.read_flow(os.path.join(forward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
        height, width = flow.shape[:2]
        flow = cv2.resize(flow, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
        # rescale the flow vectors to match the resized resolution
        flow[:, :, 0] = flow[:, :, 0] / width * self.width
        flow[:, :, 1] = flow[:, :, 1] / height * self.height
        return flow
    def read_backward_flow(self, backward_flow_dir, sampledIndex):
        # frame 0 has no backward flow, so clamp to the first flow file;
        # otherwise the backward flow of frame i is stored at index i - 1
        if sampledIndex > 0:
            sampledIndex -= 1
        flow = cvbase.read_flow(os.path.join(backward_flow_dir, '{:05d}.flo'.format(sampledIndex)))
        height, width = flow.shape[:2]
        flow = cv2.resize(flow, (self.width, self.height), interpolation=cv2.INTER_LINEAR)
        # rescale the flow vectors to match the resized resolution
        flow[:, :, 0] = flow[:, :, 0] / width * self.width
        flow[:, :, 1] = flow[:, :, 1] / height * self.height
        return flow
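    # Rescaling example (hypothetical numbers): a horizontal displacement of
    # 10 px in an 848-px-wide flow map becomes 10 / 848 * 432 ≈ 5.09 px after
    # resizing to width 432, so the vectors stay valid at the new resolution.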
    def to_tensor(self, data_list):
        """
        Args:
            data_list: A dict whose values are lists of numpy arrays

        Returns: The same dict with each non-empty list stacked into a tensor
        """
        keys = list(data_list.keys())
        for key in keys:
            if data_list[key] is None or data_list[key] == []:
                data_list.pop(key)  # drop entries for the unused flow direction
            else:
                item = data_list[key]
                if not isinstance(item, list):
                    item = torch.from_numpy(np.transpose(item, (2, 0, 1))).float()  # [c, h, w]
                else:
                    item = np.stack(item, axis=0)
                    if len(item.shape) == 3:  # [t, h, w], e.g., single-channel masks
                        item = item[:, :, :, np.newaxis]
                    item = torch.from_numpy(np.transpose(item, (0, 3, 1, 2))).float()  # [t, c, h, w]
                data_list[key] = item
        return data_list
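
# Usage sketch. The config keys match __init__ above, but every concrete
# value here is a hypothetical example, not the project's actual setting:
#
#   opt = {'sample': 'seq', 'input_resolution': (240, 432), 'num_frames': 5,
#          'flow2rgb': True, 'flow_direction': 'for'}
#   dataInfo = {'frame_path': '/data/frames', 'flow_path': '/data/flows',
#               'name2len': '/data/name2len.pkl'}
#   dataset = VideoBasedDataset(opt, dataInfo)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
#   batch = next(iter(loader))  # batch['frames'].shape == (2, 5, 3, 240, 432)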