"""Convert HDF5 robot episodes into TFRecord files, mirroring the source
directory tree and copying the expanded-instruction JSON files alongside."""
import fnmatch
import os
import shutil
from multiprocessing import Pool

import h5py
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bool_feature(value):
    """Returns an int64_list from a boolean (protos have no bool_list)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def serialize_example(action, base_action, qpos, qvel, cam_high,
                      cam_left_wrist, cam_right_wrist, instruction,
                      terminate_episode):
    """Serialize one timestep of an episode into a tf.train.Example string."""
    feature = {
        # Numeric arrays are stored as serialized tensors.
        'action': _bytes_feature(tf.io.serialize_tensor(action)),
        'base_action': _bytes_feature(tf.io.serialize_tensor(base_action)),
        'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
        'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
        # Camera frames are stored as raw byte strings (ndarray.tobytes()).
        'cam_high': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_high.tobytes(), dtype=tf.string))),
        'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_left_wrist.tobytes(), dtype=tf.string))),
        'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_right_wrist.tobytes(), dtype=tf.string))),
        'instruction': _bytes_feature(instruction),
        'terminate_episode': _bool_feature(terminate_episode),
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
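

# ---------------------------------------------------------------------------
# Illustrative counterpart to serialize_example; a sketch only, never called
# by this script. The feature keys and wire types mirror the writer above;
# the helper name and the float32 out_type are assumptions (out_type must
# match the dtype the arrays had when they were serialized).
# ---------------------------------------------------------------------------
_FEATURE_SPEC = {
    'action': tf.io.FixedLenFeature([], tf.string),
    'base_action': tf.io.FixedLenFeature([], tf.string),
    'qpos': tf.io.FixedLenFeature([], tf.string),
    'qvel': tf.io.FixedLenFeature([], tf.string),
    'cam_high': tf.io.FixedLenFeature([], tf.string),
    'cam_left_wrist': tf.io.FixedLenFeature([], tf.string),
    'cam_right_wrist': tf.io.FixedLenFeature([], tf.string),
    'instruction': tf.io.FixedLenFeature([], tf.string),
    'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
}


def parse_serialized_example(serialized):
    """Sketch: invert serialize_example for a single record."""
    parsed = tf.io.parse_single_example(serialized, _FEATURE_SPEC)
    action = tf.io.parse_tensor(parsed['action'], out_type=tf.float32)
    qpos = tf.io.parse_tensor(parsed['qpos'], out_type=tf.float32)
    # Camera features hold raw bytes; rebuilding the image array requires
    # the original dtype/shape, which the record itself does not carry.
    cam_high_bytes = tf.io.parse_tensor(parsed['cam_high'], out_type=tf.string)
    return parsed, action, qpos, cam_high_bytes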


def process_hdf5_file(args):
    """Convert one HDF5 episode into a TFRecord file, mirroring the relative
    directory layout of root_dir under out_dir."""
    filepath, root_dir, out_dir = args
    output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(filepath)
    tfrecord_path = os.path.join(output_dir, filename.replace('.hdf5', '.tfrecord'))

    # Skip episodes that were already converted in a previous run.
    if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
        return f"TFRecords already exist at {tfrecord_path}"

    try:
        with h5py.File(filepath, 'r') as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
            num_steps = f['action'].shape[0]  # timesteps in this episode

            # Trim the leading still frames: find the first timestep whose
            # joint positions deviate from the initial pose by more than EPS.
            EPS = 1e-2
            qpos_all = f['observations']['qpos'][:]
            qpos_delta = np.abs(qpos_all - qpos_all[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")

            # The instruction is constant across the episode.
            instruction = f['instruction'][()]

            # Keep one static frame before motion starts (clamped so that a
            # first_idx of 0 does not wrap around to the last timestep).
            for i in range(max(first_idx - 1, 0), num_steps):
                action = f['action'][i]
                base_action = f['base_action'][i]
                qpos = f['observations']['qpos'][i]
                qvel = f['observations']['qvel'][i]
                cam_high = f['observations']['images']['cam_high'][i]
                cam_left_wrist = f['observations']['images']['cam_left_wrist'][i]
                cam_right_wrist = f['observations']['images']['cam_right_wrist'][i]
                terminate_episode = i == num_steps - 1
                serialized_example = serialize_example(
                    action, base_action, qpos, qvel, cam_high,
                    cam_left_wrist, cam_right_wrist, instruction,
                    terminate_episode)
                writer.write(serialized_example)
    except Exception as e:
        # Log the failure and delete the partial file so the size check
        # above does not silently skip it on the next run.
        with open("error_log.txt", "a") as log:
            log.write(f"{filepath}\n")
        print(f"error at {filepath}: {e}")
        if os.path.exists(tfrecord_path):
            os.remove(tfrecord_path)
        return f"Failed to convert {filepath}"
    return f"TFRecords written to {tfrecord_path}"
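

# ---------------------------------------------------------------------------
# Illustrative single-file invocation (hypothetical paths, debugging only):
# process_hdf5_file takes the same (filepath, root_dir, out_dir) tuple that
# the worker pool passes, so one episode can be converted in isolation.
# ---------------------------------------------------------------------------
# print(process_hdf5_file((
#     "../datasets/agilex/rdt_data/some_task/episode_0.hdf5",  # hypothetical
#     "../datasets/agilex/rdt_data/",
#     "../datasets/agilex/tfrecords/",
# )))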


def write_tfrecords(root_dir, out_dir):
    """Walk root_dir, copy instruction JSONs into out_dir, and convert every
    HDF5 episode to TFRecords in parallel."""
    os.makedirs(out_dir, exist_ok=True)

    hdf5_files = []
    for root, dirs, files in os.walk(root_dir):
        target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
        # Copy the instruction file next to the TFRecords, normalizing its
        # name to expanded_instruction_gpt-4-turbo.json.
        if os.path.exists(os.path.join(root, "expanded_instruction_gpt-4-turbo.json")):
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root, "expanded_instruction_gpt-4-turbo.json"), target_path)
        elif os.path.exists(os.path.join(root, "expanded_instruction.json")):
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root, "expanded_instruction.json"), target_path)
            os.rename(os.path.join(target_path, "expanded_instruction.json"),
                      os.path.join(target_path, "expanded_instruction_gpt-4-turbo.json"))
        # Queue every episode file under this directory for conversion.
        for filename in fnmatch.filter(files, '*.hdf5'):
            filepath = os.path.join(root, filename)
            hdf5_files.append((filepath, root_dir, out_dir))

    # Convert in parallel, updating the progress bar as workers finish.
    with Pool(16) as pool:
        with tqdm(total=len(hdf5_files)) as pbar:
            for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
                pbar.update(1)

    print(f"TFRecords written to {out_dir}")
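

# ---------------------------------------------------------------------------
# Illustrative read-back sketch (an assumption, not part of this pipeline):
# collect the written .tfrecord files and feed them through the hypothetical
# parse_serialized_example helper above. Defined but never called here.
# ---------------------------------------------------------------------------
def _example_read_back(out_dir):
    record_files = []
    for root, _, files in os.walk(out_dir):
        record_files += [os.path.join(root, name)
                         for name in fnmatch.filter(files, '*.tfrecord')]
    dataset = tf.data.TFRecordDataset(record_files).map(parse_serialized_example)
    for parsed, action, qpos, cam_high_bytes in dataset.take(1):
        print(action.shape, qpos.shape, len(cam_high_bytes.numpy()))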


if __name__ == "__main__":
    # Guard the entry point so multiprocessing workers re-importing this
    # module do not re-launch the conversion.
    root_dir = "../datasets/agilex/rdt_data/"
    out_dir = "../datasets/agilex/tfrecords/"
    write_tfrecords(root_dir, out_dir)