import tensorflow as tf
import os
import numpy as np
from tqdm import tqdm


def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _bool_feature(value):
    """Returns an int64_list from a boolean (protos have no bool list, so it is stored as 0/1)."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))


def serialize_example(action, robot_obs, rgb_static, rgb_gripper, instruction, terminate_episode):
    # Map each field to a tf.train.Feature; array fields are serialized to byte strings
    feature = {
        'action': _bytes_feature(tf.io.serialize_tensor(action)),
        'robot_obs': _bytes_feature(tf.io.serialize_tensor(robot_obs)),
        'rgb_static': _bytes_feature(tf.io.serialize_tensor(rgb_static)),
        'rgb_gripper': _bytes_feature(tf.io.serialize_tensor(rgb_gripper)),
        'terminate_episode': _bool_feature(terminate_episode),
        'instruction': _bytes_feature(instruction),
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()
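
# A hedged sketch of the inverse operation: parse one serialized proto back into
# tensors with tf.io.parse_single_example. The feature spec mirrors the keys
# written by serialize_example; the out_type arguments passed to
# tf.io.parse_tensor are assumptions (CALVIN's .npz arrays are typically float64
# for actions/robot_obs and uint8 for images) -- adjust them to match the dtypes
# you actually serialized.
import tensorflow as tf  # redundant here, kept so the sketch stands alone

def parse_episode_example(serialized, array_dtype=tf.float64, image_dtype=tf.uint8):
    feature_spec = {
        'action': tf.io.FixedLenFeature([], tf.string),
        'robot_obs': tf.io.FixedLenFeature([], tf.string),
        'rgb_static': tf.io.FixedLenFeature([], tf.string),
        'rgb_gripper': tf.io.FixedLenFeature([], tf.string),
        'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
        'instruction': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(serialized, feature_spec)
    return {
        'action': tf.io.parse_tensor(parsed['action'], out_type=array_dtype),
        'robot_obs': tf.io.parse_tensor(parsed['robot_obs'], out_type=array_dtype),
        'rgb_static': tf.io.parse_tensor(parsed['rgb_static'], out_type=image_dtype),
        'rgb_gripper': tf.io.parse_tensor(parsed['rgb_gripper'], out_type=image_dtype),
        'terminate_episode': tf.cast(parsed['terminate_episode'], tf.bool),
        'instruction': parsed['instruction'],
    }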


def write_tfrecords(root_dir, out_dir):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Get the language annotation and corresponding indices
    f = np.load(os.path.join(root_dir, "lang_annotations/auto_lang_ann.npy"), allow_pickle=True)
    lang = f.item()['language']['ann']
    lang = np.array([x.encode('utf-8') for x in lang])
    lang_start_end_idx = f.item()['info']['indx']
    num_ep = len(lang_start_end_idx)
    
    with tqdm(total=num_ep) as pbar:
        for episode_idx, (start_idx, end_idx) in enumerate(lang_start_end_idx):
            pbar.update(1)

            step_files = [
                f"episode_{str(i).zfill(7)}.npz"
                for i in range(start_idx, end_idx + 1)
            ]
            action = []
            robot_obs = []
            rgb_static = []
            rgb_gripper = []
            instr = lang[episode_idx]
            for step_file in step_files:
                filepath = os.path.join(root_dir, step_file)
                step_data = np.load(filepath)  # avoid shadowing the annotation file `f` above
                # Collect the relevant arrays for this step
                action.append(step_data['actions'])
                robot_obs.append(step_data['robot_obs'])
                rgb_static.append(step_data['rgb_static'])
                rgb_gripper.append(step_data['rgb_gripper'])
                
            tfrecord_path = os.path.join(out_dir, f'{episode_idx:07d}.tfrecord')
            # Show the current file on the progress bar instead of printing once per episode
            pbar.set_description(f"Writing {tfrecord_path}")
            with tf.io.TFRecordWriter(tfrecord_path) as writer:
                for i in range(len(step_files)):
                    serialized_example = serialize_example(
                        action[i], robot_obs[i], rgb_static[i], rgb_gripper[i], instr, i == len(step_files) - 1
                    )
                    writer.write(serialized_example)

root_dirs = [
    '../datasets/calvin/task_ABC_D/training',
    '../datasets/calvin/task_ABC_D/validation'
]

output_dirs = [
    '../datasets/calvin/tfrecords/training',
    '../datasets/calvin/tfrecords/validation'
]

if __name__ == "__main__":
    # write_tfrecords creates each output directory itself, so no extra makedirs loop is needed
    for root_dir, output_dir in zip(root_dirs, output_dirs):
        print(f"Writing TFRecords from {root_dir} to {output_dir}")
        write_tfrecords(root_dir, output_dir)
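
# Optional sanity check (a suggested verification step, not part of the original
# conversion): each episode's .tfrecord file should contain exactly one record
# per step, i.e. end_idx - start_idx + 1 records for that episode.
import tensorflow as tf  # redundant here, kept so the sketch stands alone

def count_records(tfrecord_path):
    # TFRecordDataset yields one raw serialized example per record
    return sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))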