# euijinrnd's picture
# Add files using upload-large-folder tool
# 9de9fbf verified
import json
import random
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import yaml
from data.episode_transform import process_episode, flatten_episode, \
flatten_episode_agilex, bgr_to_rgb
from data.utils import dataset_to_path
from data.preprocess_scripts import *
# The producer never needs a GPU, so hide all of them from TensorFlow
tf.config.set_visible_devices([], 'GPU')

# Root directory handed to dataset_to_path() for datasets built with tfds
OPENX_EMBOD_DIR = 'data/datasets/openx_embod'

# Datasets loaded through their own `<name>.load_dataset(seed)` function
# instead of a tfds builder directory
DATASET_NAMES_NOOPENX = [
    'aloha_mobile',
    'aloha_static',
    'roboset',
    'agilex',
    'rh20t',
    'calvin',
    'bridgev2',
]

# Base configuration shared by the data pipeline
with open('configs/base.yaml', 'r') as cfg_file:
    config = yaml.safe_load(cfg_file)

# Episode-length thresholds: shorter episodes are dropped, longer ones are
# subsampled down to the upper threshold
_dataset_cfg = config['dataset']
EPSD_LEN_THRESH_LOW = _dataset_cfg['epsd_len_thresh_low']
EPSD_LEN_THRESH_HIGH = _dataset_cfg['epsd_len_thresh_high']

# Per-dataset image keys and image masks, indexed by dataset name
with open('configs/dataset_img_keys.json', 'r') as keys_file:
    IMAGE_KEYS = json.load(keys_file)
class VLADataset:
    """
    This class is used to sample episodes from the embodiment datasets.

    Every dataset named in the selected config is loaded as a pipeline,
    optionally filtered, passed through `process_episode`, and stored as an
    iterator. Iterating over an instance repeatedly picks one dataset at
    random (according to the normalized sample weights) and yields the
    flattened steps of its next episode.
    """

    def __init__(self, seed, dataset_type, repeat=True):
        '''
        seed: the random seed
        dataset_type: 'pretrain' or 'finetune', which dataset to load
            (any value other than 'pretrain' selects the finetune configs)
        repeat: whether to repeat to infinite length
        '''
        is_pretrain = dataset_type == "pretrain"

        dataset_names_cfg = 'configs/pretrain_datasets.json' \
            if is_pretrain else 'configs/finetune_datasets.json'
        with open(dataset_names_cfg, 'r') as file:
            self.dataset_names = json.load(file)

        sample_weights_cfg = 'configs/pretrain_sample_weights.json' \
            if is_pretrain else 'configs/finetune_sample_weights.json'
        # Load the sample weights (mapping: dataset name -> weight)
        with open(sample_weights_cfg, 'r') as file:
            weights_by_name = json.load(file)

        self.openx_dir = OPENX_EMBOD_DIR
        self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
        self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
        self.repeat = repeat

        # Set the random seed
        tf.random.set_seed(seed)
        np.random.seed(seed)
        # __iter__ subsamples long episodes with the stdlib `random` module,
        # so it has to be seeded too for runs to be reproducible
        random.seed(seed)

        # Weights of each dataset in the collection to sample from
        sample_weights = []
        # Maps dataset name -> iterator over its processed episodes
        self.name2dataset = {}
        for dataset_name in self.dataset_names:
            if dataset_name in DATASET_NAMES_NOOPENX:
                # These datasets provide their own loader, looked up by name
                # among the module globals
                dataset = globals()[dataset_name].load_dataset(seed)
            else:
                dataset_path = dataset_to_path(dataset_name, self.openx_dir)
                dataset = tfds.builder_from_directory(builder_dir=dataset_path)
                dataset = dataset.as_dataset(split='all', shuffle_files=True)

            # You can add filter for other datasets
            if dataset_name == 'kuka':
                dataset = dataset.filter(
                    lambda x: x['success'])
            elif dataset_name == 'bc_z':
                dataset = dataset.filter(
                    lambda x: tf.math.greater(
                        next(iter(x['steps']))['observation']['episode_success'], 0.5))
            elif dataset_name == 'ucsd_pick_and_place_dataset_converted_externally_to_rlds':
                dataset = dataset.filter(
                    lambda x: x['episode_metadata']['success'])
            elif dataset_name == 'utokyo_xarm_bimanual_converted_externally_to_rlds':
                # Only preserve the meaningful episodes
                dataset = dataset.filter(
                    lambda x: tf.math.equal(
                        next(iter(x['steps']))['language_instruction'],
                        tf.constant('Unfold a wrinkled towel.')))

            # Note: use cache() will cause the unexpected crash
            # dataset = dataset.map().cache().shuffle().repeat()
            print(dataset_name)
            dataset = dataset\
                .map(
                    lambda x: process_episode(x, dataset_name,
                        IMAGE_KEYS[dataset_name]['image_keys'],
                        IMAGE_KEYS[dataset_name]['image_mask'])
                )

            # Change BGR to RGB if needed
            if dataset_name == 'fmb':
                dataset = dataset.map(bgr_to_rgb)

            if self.repeat:
                dataset = dataset.repeat()
            self.name2dataset[dataset_name] = iter(dataset)
            print(weights_by_name)
            sample_weights.append(weights_by_name[dataset_name])

        # Normalize the sample weights
        sample_weights = np.array(sample_weights)
        self.sample_weights = sample_weights / np.sum(sample_weights)

    def __iter__(self):
        '''
        Sample episodes one at a time.

        Yields the flattened steps of one episode per iteration. Episodes
        with fewer steps than `epsd_len_thresh_low` are skipped; episodes
        with more steps than `epsd_len_thresh_high` are randomly subsampled
        down to that many steps. Iteration ends as soon as the sampled
        dataset has no more episodes.
        '''
        while True:
            dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
            try:
                episode = next(self.name2dataset[dataset_name])
            except StopIteration:
                # A finite dataset ran out. Letting StopIteration escape a
                # generator turns it into a RuntimeError (PEP 479), so end
                # the iteration explicitly instead.
                return

            if dataset_name == "agilex":
                episode_steps = flatten_episode_agilex(episode)
            else:
                episode_steps = flatten_episode(episode)

            # Filter too short
            if len(episode_steps) < self.epsd_len_thresh_low:
                continue
            # Randomly sample too long
            if len(episode_steps) > self.epsd_len_thresh_high:
                episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)

            yield episode_steps
if __name__ == "__main__":
    # Smoke test: build the finetune mixture and show the first step of the
    # first sampled episode
    vla_dataset = VLADataset(0, 'finetune')
    first_episode = next(iter(vla_dataset), None)
    if first_episode is not None:
        print(first_episode[0])