File size: 5,439 Bytes
eef26ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import tensorflow as tf
import h5py
import os
import fnmatch
import shutil
from tqdm import tqdm
from multiprocessing import Pool
import numpy as np


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bool_feature(value):
    """Returns a bool_list from a boolean."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, instruction, terminate_episode):
    feature = {
        'action': _bytes_feature(tf.io.serialize_tensor(action)),
        'base_action': _bytes_feature(tf.io.serialize_tensor(base_action)),
        'qpos': _bytes_feature(tf.io.serialize_tensor(qpos)),
        'qvel': _bytes_feature(tf.io.serialize_tensor(qvel)),
        'cam_high': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_high.tobytes(), dtype=tf.string))),
        'cam_left_wrist': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_left_wrist.tobytes(), dtype=tf.string))),
        'cam_right_wrist': _bytes_feature(tf.io.serialize_tensor(tf.convert_to_tensor(cam_right_wrist.tobytes(), dtype=tf.string))),
        'instruction': _bytes_feature(instruction),
        'terminate_episode': _bool_feature(terminate_episode)
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()


def process_hdf5_file(args):
    filepath, root_dir, out_dir = args
    output_dir = os.path.join(out_dir, os.path.relpath(os.path.dirname(filepath), root_dir))
    os.makedirs(output_dir, exist_ok=True)
    filename = os.path.basename(filepath)
    tfrecord_path = os.path.join(output_dir, filename.replace('.hdf5', '.tfrecord'))

    if os.path.exists(tfrecord_path) and os.path.getsize(tfrecord_path) > 0:
        return f"TFRecords already exist at {tfrecord_path}"
    try:
        with h5py.File(filepath, 'r') as f, tf.io.TFRecordWriter(tfrecord_path) as writer:
            num_episodes = f['action'].shape[0]
            # Remove the first few still steps
            EPS = 1e-2
            qpos = f['observations']['qpos'][:]
            # Get the idx of the first qpos whose delta exceeds the threshold
            qpos_delta = np.abs(qpos - qpos[0:1])
            indices = np.where(np.any(qpos_delta > EPS, axis=1))[0]
            if len(indices) > 0:
                first_idx = indices[0]
            else:
                raise ValueError("Found no qpos that exceeds the threshold.")
            
            for i in range(first_idx-1, num_episodes):
                action = f['action'][i]
                base_action = f['base_action'][i]
                qpos = f['observations']['qpos'][i]
                qvel = f['observations']['qvel'][i]
                cam_high = f['observations']['images']['cam_high'][i]
                cam_left_wrist = f['observations']['images']['cam_left_wrist'][i]
                cam_right_wrist = f['observations']['images']['cam_right_wrist'][i]
                instruction  = f['instruction'][()]
                terminate_episode = i == num_episodes - 1
                serialized_example = serialize_example(action, base_action, qpos, qvel, cam_high, cam_left_wrist, cam_right_wrist, instruction, terminate_episode)
                writer.write(serialized_example)
    except Exception as e:
        with open("error_log.txt", "a") as f:
            f.write(f"{filepath}\n")
        print(f"error at {filepath}: {e}")        
    return f"TFRecords written to {tfrecord_path}"


def write_tfrecords(root_dir, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    hdf5_files = []
    for root, dirs, files in os.walk(root_dir):
        if os.path.exists(os.path.join(root,"expanded_instruction_gpt-4-turbo.json")):
            # copy the instruction file
            target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root,"expanded_instruction_gpt-4-turbo.json"), target_path)
        elif os.path.exists(os.path.join(root,"expanded_instruction.json")):
            print(root)
            target_path = os.path.join(out_dir, os.path.relpath(root, root_dir))
            os.makedirs(target_path, exist_ok=True)
            shutil.copy(os.path.join(root,"expanded_instruction.json"), target_path)
            # rename into expanded_instruction_gpt-4-turbo.json
            os.rename(os.path.join(out_dir, os.path.relpath(root, root_dir), "expanded_instruction.json"), os.path.join(out_dir, os.path.relpath(root, root_dir), "expanded_instruction_gpt-4-turbo.json"))
        for filename in fnmatch.filter(files, '*.hdf5'):
            filepath = os.path.join(root, filename)
            hdf5_files.append((filepath, root_dir, out_dir))

    with Pool(16) as pool:
        max_count = len(hdf5_files)
        with tqdm(total=max_count) as pbar:
            for _ in pool.imap_unordered(process_hdf5_file, hdf5_files):
                pbar.update(1)

    print(f"TFRecords written to {out_dir}")


root_dir = "../datasets/agilex/rdt_data/"
out_dir = "../datasets/agilex/tfrecords/"
write_tfrecords(root_dir, out_dir)