euijinrnd committed on
Commit
9de9fbf
·
verified ·
1 Parent(s): 5142365

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. .gitignore +184 -0
  2. LICENSE +21 -0
  3. README.md +357 -0
  4. Untitled.ipynb +86 -0
  5. data/.gitignore +2 -0
  6. data/compute_dataset_stat.py +240 -0
  7. data/compute_dataset_stat_hdf5.py +100 -0
  8. data/episode_transform.py +406 -0
  9. data/filelock.py +24 -0
  10. data/hdf5_maniskill_dataset.py +243 -0
  11. data/hdf5_vla_dataset.py +533 -0
  12. data/preprocess.py +323 -0
  13. data/preprocess_scripts/__init__.py +73 -0
  14. data/preprocess_scripts/aloha_shoes_table.py +55 -0
  15. data/preprocess_scripts/austin_buds_dataset_converted_externally_to_rlds.py +82 -0
  16. data/preprocess_scripts/berkeley_autolab_ur5.py +95 -0
  17. data/preprocess_scripts/berkeley_cable_routing.py +73 -0
  18. data/preprocess_scripts/berkeley_gnm_sac_son.py +78 -0
  19. data/preprocess_scripts/berkeley_rpt_converted_externally_to_rlds.py +84 -0
  20. data/preprocess_scripts/calvin.py +176 -0
  21. data/preprocess_scripts/cmu_franka_exploration_dataset_converted_externally_to_rlds.py +75 -0
  22. data/preprocess_scripts/cmu_play_fusion.py +82 -0
  23. data/preprocess_scripts/cmu_stretch.py +84 -0
  24. data/preprocess_scripts/droid.py +78 -0
  25. data/preprocess_scripts/fractal20220817_data.py +92 -0
  26. data/preprocess_scripts/iamlab_cmu_pickup_insert_converted_externally_to_rlds.py +80 -0
  27. data/preprocess_scripts/libero_goal_no_noops.py +82 -0
  28. data/preprocess_scripts/libero_spatial_no_noops.py +82 -0
  29. data/preprocess_scripts/nyu_rot_dataset_converted_externally_to_rlds.py +82 -0
  30. data/preprocess_scripts/robo_net.py +71 -0
  31. data/preprocess_scripts/robomimic_lift_ph.py +97 -0
  32. data/preprocess_scripts/robomimic_square_ph.py +97 -0
  33. data/preprocess_scripts/roboset.py +367 -0
  34. data/preprocess_scripts/roboturk.py +77 -0
  35. data/preprocess_scripts/roboturk_real_objectsearch.py +217 -0
  36. data/preprocess_scripts/roboturk_real_towercreation.py +223 -0
  37. data/preprocess_scripts/stanford_hydra_dataset_converted_externally_to_rlds.py +94 -0
  38. data/preprocess_scripts/tokyo_u_lsmo_converted_externally_to_rlds.py +90 -0
  39. data/preprocess_scripts/utokyo_pr2_opening_fridge_converted_externally_to_rlds.py +92 -0
  40. data/preprocess_scripts/utokyo_xarm_bimanual_converted_externally_to_rlds.py +117 -0
  41. data/preprocess_scripts/viola.py +89 -0
  42. data/producer.py +280 -0
  43. data/utils.py +235 -0
  44. data/vla_dataset.py +147 -0
  45. encode_lang.py +60 -0
  46. finetune.sh +57 -0
  47. finetune_maniskill.sh +48 -0
  48. inference.sh +5 -0
  49. main.py +301 -0
  50. models/ema_model.py +89 -0
.gitignore ADDED
@@ -0,0 +1,184 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # Some encoder paths
163
+ facebook/
164
+ openai/
165
+ google/
166
+
167
+ # Log
168
+ logs/
169
+
170
+ # Output
171
+ outs/
172
+
173
+ # Checkpoints
174
+ checkpoints/
175
+
176
+ # VSC
177
+ .vscode/
178
+
179
+ # Wandb
180
+ wandb/
181
+
182
+ # Distributed learning
183
+ hostfile.txt
184
+ .deepspeed_env
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 TSAIL group
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,357 @@
1
+ # RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation
2
+
3
+ ### 📝[Paper](https://arxiv.org/pdf/2410.07864) | 🌍[Project Page](https://rdt-robotics.github.io/rdt-robotics/) | 🤗[Model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) | 🛢️[Data](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)
4
+
5
+ ![](./assets/head.png)
6
+
7
+ RDT-1B is a **1B**-parameter (*largest* to date) imitation learning **Diffusion Transformer** pre-trained on **1M+** (*largest* to date) multi-robot episodes. Given a language instruction and RGB images of up to three views, RDT can predict the next $64$ robot actions. RDT is inherently compatible with **almost all kinds of modern mobile manipulators**, from single-arm to dual-arm, joint to EEF, position to velocity, and even wheeled locomotion.
8
+
9
+ We have fine-tuned RDT on **6K+** (one of the *largest*) self-collected bimanual episodes and deployed it on the ALOHA **dual-arm** robot. It has achieved state-of-the-art performance in terms of dexterity, zero-shot generalizability, and few-shot learning. You can find demo videos on our [project page](https://rdt-robotics.github.io/rdt-robotics/).
10
+
11
+ This repo is an official PyTorch implementation of RDT, containing:
12
+
13
+ - 🛠️Model [implementation](models/rdt_runner.py) of RDT
14
+ - 🤗1M-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) of RDT-1B pre-trained on multi-robot data
15
+ - 🤗500K-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-170m) of RDT-170M (RDT(small) in [ablation](https://arxiv.org/pdf/2410.07864))
16
+ - 📈Training and sampling [scripts](train/train.py) (with DeepSpeed)
17
+ - 🤖An [example](scripts/agilex_inference.py) of real-robot deployment
18
+ - 🕹️Simulation benchmark from [Maniskill](https://github.com/haosulab/ManiSkill) environment
19
+
20
+ The following guides include the [installation](#installation), [fine-tuning](#fine-tuning-on-your-own-dataset), and [deployment](#deployment-on-real-robots). Please refer to [pre-training](docs/pretrain.md) for a detailed list of pre-training datasets and a pre-training guide.
21
+
22
+ ## 📰 News
23
+ - [2024/12/17] 🔥 [Scripts](#simulation-benchmark) for evaluating RDT on the ManiSkill simulation benchmark are released!
24
+ - [2024/10/23] 🔥 The smaller **RDT-170M** model is released, a more VRAM-friendly option 🚀💻.
25
+
26
+ ## Installation
27
+
28
+ 1. Clone this repo and install prerequisites:
29
+
30
+ ```bash
31
+ # Clone this repo
32
+ git clone git@github.com:thu-ml/RoboticsDiffusionTransformer.git
33
+ cd RoboticsDiffusionTransformer
34
+
35
+ # Create a Conda environment
36
+ conda create -n rdt python=3.10.0
37
+ conda activate rdt
38
+
39
+ # Install pytorch
40
+ # Look up https://pytorch.org/get-started/previous-versions/ with your cuda version for a correct command
41
+ pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
42
+
43
+ # Install packaging
44
+ pip install packaging==24.0
45
+
46
+ # Install flash-attn
47
+ pip install flash-attn --no-build-isolation
48
+
49
+ # Install other prerequisites
50
+ pip install -r requirements.txt
51
+ ```
52
+
53
+ 2. Download off-the-shelf multi-modal encoders:
54
+
55
+ You can download the encoders from the following links:
56
+
57
+ - `t5-v1_1-xxl`: [link](https://huggingface.co/google/t5-v1_1-xxl/tree/main)🤗
58
+ - `siglip`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)🤗
59
+
60
+ And link the encoders to the repo directory:
61
+
62
+ ```bash
63
+ # Under the root directory of this repo
64
+ mkdir -p google
65
+
66
+ # Link the downloaded encoders to this repo
67
+ ln -s /path/to/t5-v1_1-xxl google/t5-v1_1-xxl
68
+ ln -s /path/to/siglip-so400m-patch14-384 google/siglip-so400m-patch14-384
69
+ ```
70
+ 3. Fill the missing argument in [this file](configs/base.yaml#L22):
71
+
72
+ Note that this buffer will only be used during pre-training. See [this doc](docs/pretrain.md) for more details.
73
+ ```
74
+ # ...
75
+
76
+ dataset:
77
+ # ...
78
+ # ADD YOUR buf_path: the path to the buffer (at least 400GB)
79
+ buf_path: /path/to/buffer
80
+ # ...
81
+ ```
82
+
83
+ ## Fine-Tuning on Your Own Dataset
84
+
85
+ If your fine-tuning dataset is part of [Open X-Embodiment](https://robotics-transformer-x.github.io/) or of our pre-training dataset collection (see [this doc](docs/pretrain.md#download-and-prepare-datasets)), you can also fine-tune RDT through the pre-training pipeline; you only need to remove the redundant datasets from the parameters. See [this guide](docs/pretrain.md) on pre-training.
86
+
87
+ 1. Prepare your dataset:
88
+
89
+ You need to download your dataset to the disk and give it a name `my_cool_dataset`.
90
+
91
+ Then, you can link your dataset to the repo directory:
92
+
93
+ ```bash
94
+ # Under the root directory of this repo
95
+ cd data
96
+ mkdir -p datasets
97
+
98
+ # Link the downloaded dataset to this repo
99
+ ln -s /path/to/my_cool_dataset datasets/my_cool_dataset
100
+ ```
101
+
102
+ 2. Implement the dataset loader:
103
+
104
+ You need to:
105
+
106
+ 1. Register the configuration of `my_cool_dataset`:
107
+
108
+ Append the control frequency of `my_cool_dataset` in [this file](configs/dataset_control_freq.json). Write the name of `my_cool_dataset` in [this file](configs/finetune_datasets.json) and [this file](configs/finetune_sample_weights.json), where the value of the sampling weight doesn't matter since you only have one dataset. In these two files, we leave a placeholder of `agilex`; you can simply replace it with `my_cool_dataset`.
109
+
110
+ 2. Re-Implement the class of `HDF5VLADataset`:
111
+
112
+ You can find this class in [this file](data/hdf5_vla_dataset.py). In this file, we provide an example of loading the fine-tuning dataset used in our paper (see [this link](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)).
113
+
114
+ To adapt it to your dataset, you need to: (a) modify `HDF5_DIR` (the directory of `my_cool_dataset`) and `DATASET_NAME` (which should be `"my_cool_dataset"`) in L21 and L22; and (b) implement the two functions `parse_hdf5_file()` and `parse_hdf5_file_state_only()`. Please take a look at the original file for detailed comments and examples.
115
+
116
+ Note 1: Despite its name, you don't necessarily need to use HDF5 to store your data. Just make sure that the class is correctly implemented.
117
+
118
+ Note 2: During implementation, you may need to fill your robot's action into the unified action vector (L180-194); see the sketch after these notes. Please refer to [this file](configs/state_vec.py) for an explanation of each element in the unified vector. We have reserved enough slots for each physical quantity. For example, we have reserved ten slots for joint angles; if your robot arm has six degrees of freedom, you only need to fill in the first six.
119
+
120
+ **IMPORTANT 1:** If your robot is single-arm, please fill its action into the *right-arm* portion of the unified action vector, aligning with our pre-training datasets.
121
+
122
+ **IMPORTANT 2:** We use [6D representation](https://arxiv.org/pdf/1812.07035) for EEF rotation. If your action space contains EEF rotation (angle or quaternion), please refer to [this file](docs/test_6drot.py) for conversion. We note that this mapping is not reversible. Different Euler angles may be equivalent and correspond to the same 6D representation.
123
+
124
+ **IMPORTANT 3:** No physical quantities (except the gripper width) are normalized during pre-training. This can preserve each physical quantity's meaning, thereby promoting generalization across robots. Therefore, we encourage you not to normalize any physical quantities but to choose appropriate units for them. Generally, we use the International System of Units, which ensures that most values fall within [-1,1]. As an exception, we perform min-max normalization on the gripper width to [0,1].
125
+
126
+ **IMPORTANT 4:** If you use RTX 4090 (or lower), the GPU memory may be too low to load the `t5-v1_1-xxl` encoder. Instead, we recommend you precompute the language embeddings (see [this file](scripts/encode_lang_batch.py) for an example script) and load them during training. In this way, you need to specify the path to the embeddings in the `HDF5VLADataset` (see L148) rather than the natural language.
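To make Note 2 and IMPORTANT 1-3 above concrete, here is a minimal sketch of filling a 6-DoF single-arm robot's state into the unified vector and converting an EEF rotation matrix to the 6D representation. It assumes `STATE_VEC_IDX_MAPPING` from `configs/state_vec.py`; the key names used here (e.g., `right_arm_joint_0_pos`, `right_gripper_open`) are illustrative, so check the actual keys in that file and the exact 6D convention in `docs/test_6drot.py`.

```python
import numpy as np
from configs.state_vec import STATE_VEC_IDX_MAPPING

# Dimension of the unified state/action vector (largest index + 1).
UNI_STATE_DIM = max(STATE_VEC_IDX_MAPPING.values()) + 1

def fill_in_state(joint_pos: np.ndarray, gripper_width: np.ndarray) -> np.ndarray:
    """joint_pos: (T, 6) joint angles in radians (kept unnormalized, per IMPORTANT 3);
    gripper_width: (T,) min-max normalized to [0, 1]."""
    T = joint_pos.shape[0]
    state = np.zeros((T, UNI_STATE_DIM), dtype=np.float32)
    # A single-arm robot goes into the *right-arm* slots (IMPORTANT 1).
    for j in range(joint_pos.shape[1]):
        state[:, STATE_VEC_IDX_MAPPING[f"right_arm_joint_{j}_pos"]] = joint_pos[:, j]
    state[:, STATE_VEC_IDX_MAPPING["right_gripper_open"]] = gripper_width
    return state

def rotation_matrix_to_6d(rot_mat: np.ndarray) -> np.ndarray:
    """6D rotation representation (IMPORTANT 2): the first two columns of the
    rotation matrix, column by column. Verify the ordering against docs/test_6drot.py."""
    return np.concatenate([rot_mat[..., :, 0], rot_mat[..., :, 1]], axis=-1)
```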
127
+
128
+ 3. Compute the dataset statistics information for `my_cool_dataset`:
129
+
130
+ ```bash
131
+ # Under the root directory of this repo
132
+ # Use -h to see the full usage
133
+ python -m data.compute_dataset_stat_hdf5
134
+ ```
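After the script finishes, `configs/dataset_stat.json` contains one entry per dataset with per-dimension `state_mean`, `state_std`, `state_min`, and `state_max` over the unified state vector (see `data/compute_dataset_stat_hdf5.py` further down in this diff). A quick sanity check, assuming the dataset name registered above:

```python
import json

with open("configs/dataset_stat.json") as f:
    stats = json.load(f)

entry = stats["my_cool_dataset"]
print(len(entry["state_mean"]))                       # dimension of the unified state vector
print(entry["state_min"][:6], entry["state_max"][:6])  # ranges of the first few dimensions
```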
135
+
136
+ 3. Start fine-tuning:
137
+
138
+ Configurations relevant to model architecture and data processing are in [this file](configs/base.yaml). Normally, you should not modify them; doing so may cause errors when loading the pre-trained checkpoint. Configurations relevant to training are passed through *command-line arguments*. Use `python main.py -h` to see their descriptions. We provide an example fine-tuning script in [this file](finetune.sh) (`finetune.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
139
+
140
+ Use this to start fine-tuning:
141
+
142
+ ```bash
143
+ source finetune.sh
144
+ ```
145
+
146
+ with `finetune.sh` detailed as below:
147
+
148
+ ```bash
149
+ deepspeed --hostfile=hostfile.txt main.py \
150
+ --deepspeed="./configs/zero2.json" \ # If you want to use DeepSpeed, which is strongly recommended
151
+ --pretrained_model_name_or_path=<MODEL ID | DIRECTORY OF MODEL WEIGHTS | PATH TO MODEL CHECKPOINT> \
152
+ --pretrained_text_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY > \ # e.g., google/t5-v1_1-xxl
153
+ --pretrained_vision_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY> \ # e.g., google/siglip-so400m-patch14-384
154
+ --output_dir=<DIRECTORY to SAVE CHECKPOINTS> \ # e.g., checkpoints/rdt-1b-agilex
155
+ --train_batch_size=32 \
156
+ --sample_batch_size=64 \ # batch size for diffusion sampling in validation
157
+ --max_train_steps=200000 \
158
+ --checkpointing_period=1000 \
159
+ --sample_period=500 \ # sample period for validation
160
+ --checkpoints_total_limit=40 \
161
+ --lr_scheduler="constant" \
162
+ --learning_rate=1e-4 \
163
+ --mixed_precision="bf16" \ # If you want to use mixed precision, bf16 is recommended
164
+ --dataloader_num_workers=8 \
165
+ --image_aug \ # If you want to use image augmentation
166
+ --dataset_type="finetune" \
167
+ --state_noise_snr=40 \ # If you want to add noise to the state
168
+ --load_from_hdf5 \ # If you use HDF5 to store your data
169
+ --report_to=wandb
170
+ ```
171
+
172
+ **IMPORTANT**: If you have already chosen to precompute the language embeddings, please specify `--precomp_lang_embed` in the `finetune.sh`.
173
+
174
+ Note 1: `pretrained_model_name_or_path` can be one of:
175
+
176
+ - a string, the *model id* of a pre-trained model hosted inside a model repo on HuggingFace. Please fill with `"robotics-diffusion-transformer/rdt-1b"`, which is the officially-released [RDT-1B model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗 at HuggingFace. (recommended)
177
+ - a string, the path to a *directory* containing the manually downloaded model weights from HuggingFace, e.g., `"/path/to/rdt-1b"`. You should first manually download the `rdt-1b` directory from this [link](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗.
178
+ - a string, the path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method. This can be either:
179
+ - `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>"`: This is the path to the checkpoint saved in the `<STEP NUMBE>` iteration during pre-training. Refer to [this file](docs/pretrain.md) for a tutorial on how to start your own pre-training.
180
+ - `"checkpoints/rdt-pretrain-1b"`: If the pre-training completes normally without any exception, you can specify this path to load the last checkpoint.
181
+ - a string, the path to model checkpoint (`*.pt`) saved by DeepSpeed, e.g., `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt"` (verified)
182
+ - `None` if you want to randomly initialize the model using configuration at `config_path`.
183
+
184
+ Note 2: You can monitor the training process by observing `loss` (through a long window moving average) and `overall_avg_sample_mse` in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs. Usually, fine-tuning is over when this value converges.
185
+
186
+ Note 3: If the training oscillates, you can increase the batch size by adding more GPUs or setting a larger `--gradient_accumulation_steps`.
187
+
188
+ ## Deployment on Real-Robots
189
+
190
+ We have encapsulated the inference of the model into a class named `RoboticDiffusionTransformerModel` (see [this file](scripts/agilex_model.py#L38)). You can call this class's `step()` method for inference. However, you may need to re-implement some parts according to your specific robot. You should at least modify the `_format_joint_to_state()` (L164) and `_unformat_action_to_joint()` (L196) to convert between robot raw actions and unified action vectors that RDT accepts. You may also specify the control frequency of your robot (L49).
191
+
192
+ **IMPORTANT**: When you feed the images into `step()`, remember the order MUST be `[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_{t}, right_wrist_{t}, left_wrist_{t}]`.
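A minimal sketch (not the repo's deployment code) of keeping a two-frame history per camera and assembling the images in exactly this order; the variable names and the `step()` keyword arguments are assumptions, so check `scripts/agilex_model.py` and `scripts/agilex_inference.py` for the real interface:

```python
from collections import deque

# One two-frame buffer per camera: index 0 holds frame t-1, index 1 holds frame t.
ext_hist = deque(maxlen=2)          # exterior camera
right_wrist_hist = deque(maxlen=2)  # right-wrist camera
left_wrist_hist = deque(maxlen=2)   # left-wrist camera

def build_image_list():
    """Return images in the required order:
    [ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_t, right_wrist_t, left_wrist_t]."""
    return [
        ext_hist[0], right_wrist_hist[0], left_wrist_hist[0],  # frames at t-1
        ext_hist[1], right_wrist_hist[1], left_wrist_hist[1],  # frames at t
    ]

# Inside the control loop (pseudo-usage; `policy` is a RoboticDiffusionTransformerModel):
# ext_hist.append(read_ext_camera()); right_wrist_hist.append(...); left_wrist_hist.append(...)
# actions = policy.step(proprio=state_vec, images=build_image_list(), text_embeds=lang_embed)
```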
193
+
194
+ We provide example hardware code in [this file](scripts/agilex_inference.py) for deployment on Mobile ALOHA, and the corresponding running script in [this file](inference.sh) (`inference.sh`), which is detailed below:
195
+
196
+ ```bash
197
+ python -m scripts.agilex_inference \
198
+ --use_actions_interpolation \
199
+ --pretrained_model_name_or_path=<PATH TO MODEL CHECKPOINT> \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>, checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt, the same before
200
+ --lang_embeddings_path=<PATH TO YOUR INSTRUCTION EMBEDDINGS> \ # e.g., outs/lang_embeddings/your_instr.pt
201
+ --ctrl_freq=25 # your control frequency
202
+ ```
203
+
204
+ **IMPORTANT**: If your on-board GPU memory is not enough to encode the language, please refer to [this file](scripts/encode_lang.py) for precomputation and specify the language embedding path in `inference.sh`. Detailed instructions are provided below:
205
+
206
+ 1. Set Required Parameters in `scripts/encode_lang.py`
207
+
208
+ ```python
209
+ # ...
210
+
211
+ GPU = 0
212
+ MODEL_PATH = "google/t5-v1_1-xxl"
213
+ CONFIG_PATH = "configs/base.yaml"
214
+ SAVE_DIR = "outs/" # output directory
215
+
216
+ # Modify this to your task name and instruction
217
+ TASK_NAME = "handover_pan"
218
+ INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."
219
+
220
+ # Note: if your GPU VRAM is less than 24GB,
221
+ # it is recommended to enable offloading by specifying an offload directory.
222
+ OFFLOAD_DIR = None # Specify your offload directory here, ensuring the directory exists.
223
+
224
+ # ...
225
+ ```
226
+
227
+ 2. Run the script
228
+ ```
229
+ python -m scripts.encode_lang
230
+ ```
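The script saves the encoded instruction as a `.pt` file under `SAVE_DIR`. The on-disk structure is defined by `scripts/encode_lang.py`, so inspect it rather than assuming a fixed format; a hedged sketch (the path is only an example):

```python
import torch

emb = torch.load("outs/lang_embeddings/your_instr.pt", map_location="cpu")
if isinstance(emb, dict):
    # Look for the embedding tensor among the saved fields.
    for key, value in emb.items():
        print(key, getattr(value, "shape", type(value)))
else:
    # A plain tensor of shape (num_tokens, hidden_dim) from the T5-XXL encoder.
    print(emb.shape)
```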
231
+
232
+ Note: If you want to deploy on the Mobile ALOHA robot, don't forget to install the hardware prerequisites (see [this repo](https://github.com/MarkFzp/mobile-aloha)).
233
+
234
+ ## Simulation Benchmark
235
+
236
+ We comprehensively evaluate RDT against baseline methods using the ManiSkill simulation benchmark. Specifically, we focus on five benchmark tasks: `PegInsertionSide`, `PickCube`, `StackCube`, `PlugCharger`, and `PushCube`. Here's a brief overview of the evaluation setup:
237
+
238
+ **Evaluation Setup:**
239
+
240
+ 1. **Install ManiSkill:**
241
+ Within the [RDT environment](#installation), install ManiSkill as follows:
242
+ ```bash
243
+ conda activate rdt
244
+ pip install --upgrade mani_skill
245
+ ```
246
+
247
+ 2. **Configure Vulkan:**
248
+ Follow the [ManiSkill documentation](https://maniskill.readthedocs.io/en/latest/user_guide/getting_started/installation.html#vulkan) to properly set up Vulkan.
249
+
250
+ 3. **Obtain Model Weights:**
251
+ Download the fine-tuned model weights from [this Hugging Face repository](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/rdt). Download the precomputed language embeddings from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/lang_embeds) to the root directory of this repo.
252
+
253
+ 4. **Run Evaluation Scripts:**
254
+ After completing the setup steps, execute the provided evaluation scripts to assess RDT on the selected tasks.
255
+
256
+ ```
257
+ conda activate rdt
258
+ python -m eval_sim.eval_rdt_maniskill \
259
+ --pretrained_path PATH_TO_PRETRAINED_MODEL
260
+ ```
261
+
262
+ ### Implementation Details
263
+
264
+ #### Data
265
+
266
+ Utilizing the [official ManiSkill repository](https://github.com/haosulab/ManiSkill), we generated 5,000 trajectories through motion planning. The initial action mode of these trajectories is absolute joint-position control; we subsequently converted them to delta end-effector pose control to align with the pre-training action spaces of OpenVLA and Octo. We strictly adhered to the official codebases of OpenVLA and Octo, modifying only the dataset-loading scripts, and fine-tuned both models on the delta end-effector pose data. For RDT and Diffusion-Policy, we use joint-position control data for training, which is also aligned with our pre-training stage.
267
+
268
+ #### Training
269
+ - OpenVLA is fine-tuned from the officially released pre-trained checkpoint with LoRA rank 32 until convergence.
270
+ - Octo is fine-tuned from the officially released pre-trained checkpoint for 1M iterations until convergence.
271
+ - Diffusion-Policy is trained from scratch for 1,000 epochs. We select the checkpoint at epoch 700, which has the lowest validation sample loss (1e-3).
272
+ - RDT is fine-tuned from our released pre-trained checkpoint for 300K iterations.
273
+
274
+ #### Results
275
+
276
+ Each method is evaluated over 250 trials (10 random seeds with 25 trials per seed). The quantitative results, including the mean and standard deviation of the success rate across the 10 random seeds, are presented below:
277
+
278
+
279
+ ||PegInsertionSide|PickCube|StackCube|PlugCharger|PushCube|Mean|
280
+ |---|---|---|---|---|---|---|
281
+ |RDT|**13.2±0.29%**|**77.2±0.48%**|74.0±0.30%|**1.2±0.07%**|**100±0.00%**|**53.6±0.52%**|
282
+ |OpenVLA|0.0±0.00%|8.0±0.00%|8.0±0.00%|0.0±0.00%|8.0±0.00%|4.8±0.00%|
283
+ |Octo|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|
284
+ |Diffusion-Policy|0.0±0.00%|40.0±0.00%|**80.0±0.00%**|0.0±0.00%|88.0±0.00%|30.2±0.00%|
285
+
286
+ #### Finetune RDT with Maniskill Data
287
+
288
+ To fine-tune RDT with Maniskill data, first download the Maniskill data from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) and extract it to `data/datasets/rdt-ft-data`. Then copy the code in `data/hdf5_vla_dataset.py` to `data/hdf5_maniskill_dataset.py` and run the following script:
289
+
290
+ ```
291
+ bash finetune_maniskill.sh
292
+ ```
293
+
294
+ #### Reproducing Baseline Results
295
+
296
+ Download and extract the fine-tuned model weights from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) to `eval_sim/`.
297
+
298
+ - OpenVLA: Clone [OpenVLA repo](https://github.com/openvla/openvla) in `./eval_sim/` and install its environment & ManiSkill. Then run the following script:
299
+ ```
300
+ python -m eval_sim.eval_openvla --pretrained_path PATH_TO_PRETRAINED_MODEL
301
+ ```
302
+ - Octo: Clone the [Octo repo](https://github.com/octo-models/octo.git) in `./eval_sim/` and install its environment & ManiSkill. Then run the following script:
303
+ ```
304
+ python -m eval_sim.eval_octo --pretrained_path PATH_TO_PRETRAINED_MODEL
305
+ ```
306
+ - Diffusion-Policy: Clone our simplified [Diffusion-Policy repo](https://github.com/LBG21/RDT-Eval-Diffusion-Policy) in `./eval_sim/` and run:
307
+ ```
308
+ python -m eval_sim.eval_dp --pretrained_path PATH_TO_PRETRAINED_MODEL
309
+ ```
310
+
311
+ ## FAQ
312
+
313
+ ### 1. How can I fine-tune RDTs with limited VRAM?
314
+
315
+ - **Use a Smaller Model**: Opt for the [RDT-170M model](https://huggingface.co/robotics-diffusion-transformer/rdt-170m), which requires less VRAM.
316
+
317
+ - **Select a Memory-Efficient ZeRO Stage**: Choose a more memory-efficient ZeRO stage based on your needs:
318
+ - **ZeRO-3 with Offload** > **ZeRO-3** > **ZeRO-2 with Offload** > **ZeRO-2** > **ZeRO-1**
319
+ - By default, we use [ZeRO-2](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/finetune.sh#L29) for a balance between speed and memory efficiency. Find more details on ZeRO stages [here](https://huggingface.co/docs/transformers/main/deepspeed#select-a-zero-stage) and [here](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training).
320
+
321
+ - **Enable 8-bit Adam Optimization**: Activate 8-bit Adam by setting [`use_8bit_adam=True`](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/main.py#L195) for reduced memory usage during training.
322
+
323
+ - **Apply 4-bit or 8-bit Quantization**: Quantizing model weights can significantly reduce VRAM requirements.
324
+
325
+ - **Use [XFormers](https://github.com/facebookresearch/xformers)**: This library provides optimized transformers with efficient memory usage.
326
+
327
+ - **Enable Gradient Checkpointing**: Implement `gradient_checkpointing` manually to save memory during backpropagation. See [here](https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html) for instructions. Once you have successfully implemented this feature, we welcome you to submit a PR👏.
328
+ - **Gradient Accumulation**: Set a larger `--gradient_accumulation_steps=<num_steps>`. This will accumulate the gradients of `<num_steps>` batches for backpropagation. Equivalently, this will increase the batch size by `<num_steps>` times, at the cost of `<num_steps>` times the running time.
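For example (illustrative numbers, assuming `--train_batch_size` is the per-GPU batch size): with 8 GPUs, `--train_batch_size=32`, and `--gradient_accumulation_steps=4`, the effective batch size is 8 × 32 × 4 = 1024, and each optimizer step takes roughly 4× as long.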
329
+
330
+ ### 2. How many steps are recommended for fine-tuning RDT?
331
+
332
+ Regardless of the batch size you select, it is recommended to train for at least 150K steps to achieve optimal results.
333
+
334
+ ### 3. What to do if t5-xxL is too large to store in GPU memory?
335
+
336
+ 1. Do not load T5-XXL in your GPU memory when training. Pre-compute language embeddings in advance.
337
+ 2. Set `OFFLOAD_DIR` to enable CPU offloading in `scripts/encode_lang_batch.py` and `scripts/encode_lang.py`.
338
+ 3. Use a smaller version of T5, such as t5-base, instead of t5-xxl.
339
+
340
+ ## Citation
341
+
342
+ If you find our work helpful, please cite us:
343
+
344
+ ```bibtex
345
+ @article{liu2024rdt,
346
+ title={RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation},
347
+ author={Liu, Songming and Wu, Lingxuan and Li, Bangguo and Tan, Hengkai and Chen, Huayu and Wang, Zhengyi and Xu, Ke and Su, Hang and Zhu, Jun},
348
+ journal={arXiv preprint arXiv:2410.07864},
349
+ year={2024}
350
+ }
351
+ ```
352
+
353
+ Thank you!
354
+
355
+ ## License
356
+
357
+ All the code, model weights, and data are licensed under [MIT license](./LICENSE).
Untitled.ipynb ADDED
@@ -0,0 +1,86 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "71e6c6b4-1e9b-4abb-b36e-80f7fa919c93",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import h5py\n",
11
+ "import os\n",
12
+ "import fnmatch"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 5,
18
+ "id": "6bc2e70e-970c-4874-8ce2-5950acbf74d6",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "dataset_name = 'singlevla_benchmark'\n",
23
+ "HDF5_DIR = f\"/home/shared/{dataset_name}/\"\n",
24
+ "DATASET_NAME = dataset_name\n",
25
+ "\n",
26
+ "file_paths = []\n",
27
+ "for root, _, files in os.walk(HDF5_DIR):\n",
28
+ " for filename in fnmatch.filter(files, '*.hdf5'):\n",
29
+ " file_path = os.path.join(root, filename)\n",
30
+ " file_paths.append(file_path)"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 7,
36
+ "id": "a43097df-bd96-4300-9473-748ce19406c4",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "f = h5py.File(file_paths[0], 'r')"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 9,
46
+ "id": "3bb5763f-4156-4dba-ac19-66361837afbd",
47
+ "metadata": {},
48
+ "outputs": [
49
+ {
50
+ "data": {
51
+ "text/plain": [
52
+ "<KeysViewHDF5 ['ee_pos', 'joint_pos', 'leftview_image', 'rightview_image']>"
53
+ ]
54
+ },
55
+ "execution_count": 9,
56
+ "metadata": {},
57
+ "output_type": "execute_result"
58
+ }
59
+ ],
60
+ "source": [
61
+ "f['observation'].keys()"
62
+ ]
63
+ }
64
+ ],
65
+ "metadata": {
66
+ "kernelspec": {
67
+ "display_name": "Python 3 (ipykernel)",
68
+ "language": "python",
69
+ "name": "python3"
70
+ },
71
+ "language_info": {
72
+ "codemirror_mode": {
73
+ "name": "ipython",
74
+ "version": 3
75
+ },
76
+ "file_extension": ".py",
77
+ "mimetype": "text/x-python",
78
+ "name": "python",
79
+ "nbconvert_exporter": "python",
80
+ "pygments_lexer": "ipython3",
81
+ "version": "3.10.15"
82
+ }
83
+ },
84
+ "nbformat": 4,
85
+ "nbformat_minor": 5
86
+ }
data/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ # Ignore data files
2
+ datasets
data/compute_dataset_stat.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ This file will compute the min, max, mean, and standard deviation of each dataset
3
+ in `pretrain_datasets.json` or `finetune_datasets.json`.
4
+ """
5
+
6
+ import json
7
+ import argparse
8
+ import os
9
+ # from multiprocessing import Pool, Manager
10
+
11
+ import tensorflow as tf
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ from data.vla_dataset import VLADataset
16
+ from data.hdf5_vla_dataset import HDF5VLADataset
17
+ from data.preprocess import generate_json_state
18
+
19
+
20
+ # Process each dataset to get the statistics
21
+ @tf.autograph.experimental.do_not_convert
22
+ def process_dataset(name_dataset_pair):
23
+ # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
24
+ dataset_iter = name_dataset_pair[1]
25
+
26
+ MAX_EPISODES = 100000
27
+ EPS = 1e-8
28
+ # For debugging
29
+ # MAX_EPISODES = 10
30
+ episode_cnt = 0
31
+ state_sum = 0
32
+ state_sum_sq = 0
33
+ z_state_sum = 0
34
+ z_state_sum_sq = 0
35
+ state_cnt = 0
36
+ nz_state_cnt = None
37
+ state_max = None
38
+ state_min = None
39
+ for episode in dataset_iter:
40
+ episode_cnt += 1
41
+ if episode_cnt % 1000 == 0:
42
+ print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
43
+ if episode_cnt > MAX_EPISODES:
44
+ break
45
+ episode_dict = episode['episode_dict']
46
+ dataset_name = episode['dataset_name']
47
+
48
+ res_tup = generate_json_state(
49
+ episode_dict, dataset_name
50
+ )
51
+ states = res_tup[1]
52
+
53
+ # Convert to numpy
54
+ states = states.numpy()
55
+
56
+ # Zero the values that are close to zero
57
+ z_states = states.copy()
58
+ z_states[np.abs(states) <= EPS] = 0
59
+ # Compute the non-zero count
60
+ if nz_state_cnt is None:
61
+ nz_state_cnt = np.zeros(states.shape[1])
62
+ nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
63
+
64
+ # Update statistics
65
+ state_sum += np.sum(states, axis=0)
66
+ state_sum_sq += np.sum(states**2, axis=0)
67
+ z_state_sum += np.sum(z_states, axis=0)
68
+ z_state_sum_sq += np.sum(z_states**2, axis=0)
69
+ state_cnt += states.shape[0]
70
+ if state_max is None:
71
+ state_max = np.max(states, axis=0)
72
+ state_min = np.min(states, axis=0)
73
+ else:
74
+ state_max = np.maximum(state_max, np.max(states, axis=0))
75
+ state_min = np.minimum(state_min, np.min(states, axis=0))
76
+
77
+ # Add one to avoid division by zero
78
+ nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
79
+
80
+ result = {
81
+ "dataset_name": name_dataset_pair[0],
82
+ "state_mean": (state_sum / state_cnt).tolist(),
83
+ "state_std": np.sqrt(
84
+ np.maximum(
85
+ (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
86
+ np.zeros_like(state_sum_sq)
87
+ )
88
+ ).tolist(),
89
+ "state_min": state_min.tolist(),
90
+ "state_max": state_max.tolist(),
91
+ }
92
+
93
+ return result
94
+
95
+
96
+ def process_hdf5_dataset(vla_dataset):
97
+ EPS = 1e-8
98
+ episode_cnt = 0
99
+ state_sum = 0
100
+ state_sum_sq = 0
101
+ z_state_sum = 0
102
+ z_state_sum_sq = 0
103
+ state_cnt = 0
104
+ nz_state_cnt = None
105
+ state_max = None
106
+ state_min = None
107
+ for i in tqdm(range(len(vla_dataset))):
108
+ episode = vla_dataset.get_item(i, state_only=True)
109
+ episode_cnt += 1
110
+
111
+ states = episode['state']
112
+
113
+ # Zero the values that are close to zero
114
+ z_states = states.copy()
115
+ z_states[np.abs(states) <= EPS] = 0
116
+ # Compute the non-zero count
117
+ if nz_state_cnt is None:
118
+ nz_state_cnt = np.zeros(states.shape[1])
119
+ nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
120
+
121
+ # Update statistics
122
+ state_sum += np.sum(states, axis=0)
123
+ state_sum_sq += np.sum(states**2, axis=0)
124
+ z_state_sum += np.sum(z_states, axis=0)
125
+ z_state_sum_sq += np.sum(z_states**2, axis=0)
126
+ state_cnt += states.shape[0]
127
+ if state_max is None:
128
+ state_max = np.max(states, axis=0)
129
+ state_min = np.min(states, axis=0)
130
+ else:
131
+ state_max = np.maximum(state_max, np.max(states, axis=0))
132
+ state_min = np.minimum(state_min, np.min(states, axis=0))
133
+
134
+ # Add one to avoid division by zero
135
+ nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
136
+
137
+ result = {
138
+ "dataset_name": vla_dataset.get_dataset_name(),
139
+ "state_mean": (state_sum / state_cnt).tolist(),
140
+ "state_std": np.sqrt(
141
+ np.maximum(
142
+ (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
143
+ np.zeros_like(state_sum_sq)
144
+ )
145
+ ).tolist(),
146
+ "state_min": state_min.tolist(),
147
+ "state_max": state_max.tolist(),
148
+ }
149
+
150
+ return result
151
+
152
+
153
+ if __name__ == "__main__":
154
+ parser = argparse.ArgumentParser()
155
+ # Multiprocessing currently with bugs
156
+ # parser.add_argument('--n_workers', type=int, default=1,
157
+ # help="Number of parallel workers.")
158
+ parser.add_argument('--dataset_type', type=str,
159
+ default="pretrain",
160
+ help="Whether to load the pretrain dataset or finetune dataset.")
161
+ parser.add_argument('--save_path', type=str,
162
+ default="configs/dataset_stat.json",
163
+ help="JSON file path to save the dataset statistics.")
164
+ parser.add_argument('--skip_exist', action='store_true',
165
+ help="Whether to skip the existing dataset statistics.")
166
+ parser.add_argument('--hdf5_dataset', action='store_true',
167
+ help="Whether to load the dataset from the HDF5 files.")
168
+ args = parser.parse_args()
169
+
170
+ if args.hdf5_dataset:
171
+ vla_dataset = HDF5VLADataset()
172
+ dataset_name = vla_dataset.get_dataset_name()
173
+
174
+ try:
175
+ with open(args.save_path, 'r') as f:
176
+ results = json.load(f)
177
+ except FileNotFoundError:
178
+ results = {}
179
+ if args.skip_exist and dataset_name in results:
180
+ print(f"Skipping existed {dataset_name} dataset statistics")
181
+ else:
182
+ print(f"Processing {dataset_name} dataset")
183
+ result = process_hdf5_dataset(vla_dataset)
184
+ results[result["dataset_name"]] = result
185
+ with open(args.save_path, 'w') as f:
186
+ json.dump(results, f, indent=4)
187
+ print("All datasets have been processed.")
188
+ os._exit(0)
189
+
190
+ vla_dataset = VLADataset(
191
+ seed=0, dataset_type=args.dataset_type, repeat=False)
192
+ name_dataset_pairs = vla_dataset.name2dataset.items()
193
+ # num_workers = args.n_workers
194
+
195
+ for name_dataset_pair in tqdm(name_dataset_pairs):
196
+ try:
197
+ with open(args.save_path, 'r') as f:
198
+ results = json.load(f)
199
+ except FileNotFoundError:
200
+ results = {}
201
+
202
+ if args.skip_exist and name_dataset_pair[0] in results:
203
+ print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
204
+ continue
205
+ print(f"Processing {name_dataset_pair[0]} dataset")
206
+
207
+ result = process_dataset(name_dataset_pair)
208
+
209
+ results[result["dataset_name"]] = result
210
+
211
+ # Save the results in the json file after each dataset (for resume)
212
+ with open(args.save_path, 'w') as f:
213
+ json.dump(results, f, indent=4)
214
+
215
+ print("All datasets have been processed.")
216
+
217
+ # with Manager() as manager:
218
+ # # Create shared dictionary and lock through the manager, accessible by all processes
219
+ # progress = manager.dict(processed=0, results={})
220
+ # progress_lock = manager.Lock()
221
+
222
+ # # Callback function to update progress
223
+ # def update_progress(result):
224
+ # with progress_lock:
225
+ # progress['processed'] += 1
226
+ # print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
227
+ # # Append the result to the shared dictionary
228
+ # progress['results'][result["dataset_name"]] = result
229
+
230
+ # with Pool(num_workers) as p:
231
+ # for name_dataset_pair in name_dataset_pairs:
232
+ # p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)
233
+
234
+ # # Close the pool and wait for the work to finish
235
+ # p.close()
236
+ # p.join()
237
+
238
+ # # Save the results in the json file
239
+ # with open(args.save_path, 'w') as f:
240
+ # json.dump(progress['results'], f, indent=4)
data/compute_dataset_stat_hdf5.py ADDED
@@ -0,0 +1,100 @@
1
+ """
2
+ This file will compute the min, max, mean, and standard deviation of the states
3
+ in an HDF5 dataset loaded through `data/hdf5_vla_dataset.py`.
4
+ """
5
+
6
+ import json
7
+ import argparse
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ # from data.hdf5_vla_dataset import TabletopHDF5VLADataset as HDF5VLADataset
13
+ from data.hdf5_vla_dataset import AnubisHDF5VLADataset as HDF5VLADataset
14
+
15
+
16
+ def process_hdf5_dataset(vla_dataset):
17
+ EPS = 1e-8
18
+ episode_cnt = 0
19
+ state_sum = 0
20
+ state_sum_sq = 0
21
+ z_state_sum = 0
22
+ z_state_sum_sq = 0
23
+ state_cnt = 0
24
+ nz_state_cnt = None
25
+ state_max = None
26
+ state_min = None
27
+ for i in tqdm(range(len(vla_dataset))):
28
+ # print(i)
29
+ episode = vla_dataset.get_item(i, state_only=True)
30
+ episode_cnt += 1
31
+
32
+ states = episode['state']
33
+
34
+ # Zero the values that are close to zero
35
+ z_states = states.copy()
36
+ z_states[np.abs(states) <= EPS] = 0
37
+ # Compute the non-zero count
38
+ if nz_state_cnt is None:
39
+ nz_state_cnt = np.zeros(states.shape[1])
40
+ nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
41
+
42
+ # Update statistics
43
+ state_sum += np.sum(states, axis=0)
44
+ state_sum_sq += np.sum(states**2, axis=0)
45
+ z_state_sum += np.sum(z_states, axis=0)
46
+ z_state_sum_sq += np.sum(z_states**2, axis=0)
47
+ state_cnt += states.shape[0]
48
+ if state_max is None:
49
+ state_max = np.max(states, axis=0)
50
+ state_min = np.min(states, axis=0)
51
+ else:
52
+ state_max = np.maximum(state_max, np.max(states, axis=0))
53
+ state_min = np.minimum(state_min, np.min(states, axis=0))
54
+
55
+ # Add one to avoid division by zero
56
+ nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
57
+
58
+ result = {
59
+ "dataset_name": vla_dataset.get_dataset_name(),
60
+ "state_mean": (state_sum / state_cnt).tolist(),
61
+ "state_std": np.sqrt(
62
+ np.maximum(
63
+ (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
64
+ np.zeros_like(state_sum_sq)
65
+ )
66
+ ).tolist(),
67
+ "state_min": state_min.tolist(),
68
+ "state_max": state_max.tolist(),
69
+ }
70
+
71
+ return result
72
+
73
+
74
+ if __name__ == "__main__":
75
+ parser = argparse.ArgumentParser()
76
+ parser.add_argument('--save_path', type=str,
77
+ default="configs/dataset_stat.json",
78
+ help="JSON file path to save the dataset statistics.")
79
+ parser.add_argument('--skip_exist', action='store_true',
80
+ help="Whether to skip the existing dataset statistics.")
81
+ parser.add_argument('--dataset', type=str)
82
+ args = parser.parse_args()
83
+
84
+ vla_dataset = HDF5VLADataset(args.dataset)
85
+ dataset_name = vla_dataset.get_dataset_name()
86
+
87
+ try:
88
+ with open(args.save_path, 'r') as f:
89
+ results = json.load(f)
90
+ except FileNotFoundError:
91
+ results = {}
92
+ if args.skip_exist and dataset_name in results:
93
+ print(f"Skipping existed {dataset_name} dataset statistics")
94
+ else:
95
+ print(f"Processing {dataset_name} dataset")
96
+ result = process_hdf5_dataset(vla_dataset)
97
+ results[result["dataset_name"]] = result
98
+ with open(args.save_path, 'w') as f:
99
+ json.dump(results, f, indent=4)
100
+ print("All datasets have been processed.")
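Example usage, based on the argument parser above: `python -m data.compute_dataset_stat_hdf5 --dataset <dataset_name> --save_path configs/dataset_stat.json --skip_exist`; the `--dataset` value is passed directly to the `HDF5VLADataset` constructor.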
data/episode_transform.py ADDED
@@ -0,0 +1,406 @@
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import yaml
4
+
5
+ from data.preprocess import generate_json_state
6
+ from configs.state_vec import STATE_VEC_IDX_MAPPING
7
+
8
+
9
+ # Read the config
10
+ with open('configs/base.yaml', 'r') as file:
11
+ config = yaml.safe_load(file)
12
+ # Load some constants from the config
13
+ IMG_HISTORY_SIZE = config['common']['img_history_size']
14
+ if IMG_HISTORY_SIZE < 1:
15
+ raise ValueError("Config `img_history_size` must be at least 1.")
16
+ ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
17
+ if ACTION_CHUNK_SIZE < 1:
18
+ raise ValueError("Config `action_chunk_size` must be at least 1.")
19
+
20
+
21
+ @tf.function
22
+ def process_episode(epsd: dict, dataset_name: str,
23
+ image_keys: list, image_mask: list) -> dict:
24
+ """
25
+ Process an episode to extract the frames and the json content.
26
+ """
27
+ # Frames of each camera
28
+ # Ugly code due to tf's poor compatibility
29
+ frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
30
+ frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
31
+ frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
32
+ frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
33
+ # Traverse the episode to collect...
34
+ for step in iter(epsd['steps']):
35
+ # Parse the image
36
+ frames_0 = frames_0.write(frames_0.size(),
37
+ tf.cond(
38
+ tf.equal(image_mask[0], 1),
39
+ lambda: step['observation'][image_keys[0]],
40
+ lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
41
+ ))
42
+ # Very ugly code due to tf's poor compatibility
43
+ frames_1 = frames_1.write(frames_1.size(),
44
+ tf.cond(
45
+ tf.equal(image_mask[1], 1),
46
+ lambda: step['observation'][image_keys[1]],
47
+ lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
48
+ ))
49
+ # print(image_mask)
50
+ frames_2 = frames_2.write(frames_2.size(),
51
+ tf.cond(
52
+ tf.equal(image_mask[2], 1),
53
+ lambda: step['observation'][image_keys[2]],
54
+ lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
55
+ ))
56
+ frames_3 = frames_3.write(frames_3.size(),
57
+ tf.cond(
58
+ tf.equal(image_mask[3], 1),
59
+ lambda: step['observation'][image_keys[3]],
60
+ lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
61
+ ))
62
+
63
+
64
+ # Calculate the past_frames_0 for each step
65
+ # Each step has a window of previous frames with size IMG_HISTORY_SIZE
66
+ # Use the first state to pad the frames
67
+ # past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels)
68
+ frames_0 = frames_0.stack()
69
+ first_frame = tf.expand_dims(frames_0[0], axis=0)
70
+ first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
71
+ padded_frames_0 = tf.concat([first_frame, frames_0], axis=0)
72
+ indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE)
73
+ past_frames_0 = tf.map_fn(
74
+ lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i],
75
+ indices,
76
+ dtype=tf.uint8
77
+ )
78
+ frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool)
79
+ padded_frames_0_time_mask = tf.pad(frames_0_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
80
+ past_frames_0_time_mask = tf.map_fn(
81
+ lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i],
82
+ indices,
83
+ dtype=tf.bool
84
+ )
85
+
86
+ # For past_frames_1
87
+ frames_1 = frames_1.stack()
88
+ first_frame = tf.expand_dims(frames_1[0], axis=0)
89
+ first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
90
+ padded_frames_1 = tf.concat([first_frame, frames_1], axis=0)
91
+ indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE)
92
+ past_frames_1 = tf.map_fn(
93
+ lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i],
94
+ indices,
95
+ dtype=tf.uint8
96
+ )
97
+ frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool)
98
+ padded_frames_1_time_mask = tf.pad(frames_1_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
99
+ past_frames_1_time_mask = tf.map_fn(
100
+ lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i],
101
+ indices,
102
+ dtype=tf.bool
103
+ )
104
+
105
+ # For past_frames_2
106
+ frames_2 = frames_2.stack()
107
+ first_frame = tf.expand_dims(frames_2[0], axis=0)
108
+ first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
109
+ padded_frames_2 = tf.concat([first_frame, frames_2], axis=0)
110
+ indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE)
111
+ past_frames_2 = tf.map_fn(
112
+ lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i],
113
+ indices,
114
+ dtype=tf.uint8
115
+ )
116
+ frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool)
117
+ padded_frames_2_time_mask = tf.pad(frames_2_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
118
+ past_frames_2_time_mask = tf.map_fn(
119
+ lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i],
120
+ indices,
121
+ dtype=tf.bool
122
+ )
123
+
124
+ # For past_frames_3
125
+ frames_3 = frames_3.stack()
126
+ first_frame = tf.expand_dims(frames_3[0], axis=0)
127
+ first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
128
+ padded_frames_3 = tf.concat([first_frame, frames_3], axis=0)
129
+ indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE)
130
+ past_frames_3 = tf.map_fn(
131
+ lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i],
132
+ indices,
133
+ dtype=tf.uint8
134
+ )
135
+ frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool)
136
+ padded_frames_3_time_mask = tf.pad(frames_3_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
137
+ past_frames_3_time_mask = tf.map_fn(
138
+ lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i],
139
+ indices,
140
+ dtype=tf.bool
141
+ )
142
+
143
+ # Create the ids for each step
144
+ step_id = tf.range(0, tf.shape(frames_0)[0])
145
+
146
+ return {
147
+ 'dataset_name': dataset_name,
148
+ 'episode_dict': epsd,
149
+ 'step_id': step_id,
150
+ 'past_frames_0': past_frames_0,
151
+ 'past_frames_0_time_mask': past_frames_0_time_mask,
152
+ 'past_frames_1': past_frames_1,
153
+ 'past_frames_1_time_mask': past_frames_1_time_mask,
154
+ 'past_frames_2': past_frames_2,
155
+ 'past_frames_2_time_mask': past_frames_2_time_mask,
156
+ 'past_frames_3': past_frames_3,
157
+ 'past_frames_3_time_mask': past_frames_3_time_mask,
158
+ }
159
+
160
+
161
+ @tf.function
162
+ def bgr_to_rgb(epsd: dict):
163
+ """
164
+ Convert BGR images to RGB images.
165
+ """
166
+ past_frames_0 = epsd['past_frames_0']
167
+ past_frames_0 = tf.cond(
168
+ tf.equal(tf.shape(past_frames_0)[-1], 3),
169
+ lambda: tf.stack([
170
+ past_frames_0[..., 2],
171
+ past_frames_0[..., 1],
172
+ past_frames_0[..., 0]
173
+ ], axis=-1),
174
+ lambda: past_frames_0
175
+ )
176
+
177
+ past_frames_1 = epsd['past_frames_1']
178
+ past_frames_1 = tf.cond(
179
+ tf.equal(tf.shape(past_frames_1)[-1], 3),
180
+ lambda: tf.stack([
181
+ past_frames_1[..., 2],
182
+ past_frames_1[..., 1],
183
+ past_frames_1[..., 0]
184
+ ], axis=-1),
185
+ lambda: past_frames_1
186
+ )
187
+
188
+ past_frames_2 = epsd['past_frames_2']
189
+ past_frames_2 = tf.cond(
190
+ tf.equal(tf.shape(past_frames_2)[-1], 3),
191
+ lambda: tf.stack([
192
+ past_frames_2[..., 2],
193
+ past_frames_2[..., 1],
194
+ past_frames_2[..., 0]
195
+ ], axis=-1),
196
+ lambda: past_frames_2
197
+ )
198
+
199
+ past_frames_3 = epsd['past_frames_3']
200
+ past_frames_3 = tf.cond(
201
+ tf.equal(tf.shape(past_frames_3)[-1], 3),
202
+ lambda: tf.stack([
203
+ past_frames_3[..., 2],
204
+ past_frames_3[..., 1],
205
+ past_frames_3[..., 0]
206
+ ], axis=-1),
207
+ lambda: past_frames_3
208
+ )
209
+
210
+ return {
211
+ 'dataset_name': epsd['dataset_name'],
212
+ 'episode_dict': epsd['episode_dict'],
213
+ 'step_id': epsd['step_id'],
214
+ 'past_frames_0': past_frames_0,
215
+ 'past_frames_0_time_mask': epsd['past_frames_0_time_mask'],
216
+ 'past_frames_1': past_frames_1,
217
+ 'past_frames_1_time_mask': epsd['past_frames_1_time_mask'],
218
+ 'past_frames_2': past_frames_2,
219
+ 'past_frames_2_time_mask': epsd['past_frames_2_time_mask'],
220
+ 'past_frames_3': past_frames_3,
221
+ 'past_frames_3_time_mask': epsd['past_frames_3_time_mask'],
222
+ }
223
+
224
+
225
+ def flatten_episode(episode: dict) -> tf.data.Dataset:
226
+ """
227
+ Flatten the episode to a list of steps.
228
+ """
229
+ episode_dict = episode['episode_dict']
230
+ dataset_name = episode['dataset_name']
231
+
232
+ json_content, states, masks = generate_json_state(
233
+ episode_dict, dataset_name
234
+ )
235
+
236
+ # Calculate the past_states for each step
237
+ # Each step has a window of previous states with size ACTION_CHUNK_SIZE
238
+ # Use the first state to pad the states
239
+ # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
240
+ first_state = tf.expand_dims(states[0], axis=0)
241
+ first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE-1, axis=0)
242
+ padded_states = tf.concat([first_state, states], axis=0)
243
+ indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
244
+ past_states = tf.map_fn(
245
+ lambda i: padded_states[i - ACTION_CHUNK_SIZE:i],
246
+ indices,
247
+ dtype=tf.float32
248
+ )
249
+ states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
250
+ padded_states_time_mask = tf.pad(states_time_mask, [[ACTION_CHUNK_SIZE-1, 0]], "CONSTANT", constant_values=False)
251
+ past_states_time_mask = tf.map_fn(
252
+ lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
253
+ indices,
254
+ dtype=tf.bool
255
+ )
256
+
257
+ # Calculate the future_states for each step
258
+ # Each step has a window of future states with size ACTION_CHUNK_SIZE
259
+ # Use the last state to pad the states
260
+ # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
261
+ last_state = tf.expand_dims(states[-1], axis=0)
262
+ last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
263
+ padded_states = tf.concat([states, last_state], axis=0)
264
+ indices = tf.range(1, tf.shape(states)[0] + 1)
265
+ future_states = tf.map_fn(
266
+ lambda i: padded_states[i:i + ACTION_CHUNK_SIZE],
267
+ indices,
268
+ dtype=tf.float32
269
+ )
270
+ states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
271
+ padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
272
+ future_states_time_mask = tf.map_fn(
273
+ lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
274
+ indices,
275
+ dtype=tf.bool
276
+ )
277
+
278
+ # Calculate the mean and std for state
279
+ state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
280
+ state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
281
+ state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
282
+ state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
283
+
284
+ state_norm = tf.math.reduce_mean(
285
+ tf.math.square(states), axis=0, keepdims=True)
286
+ state_norm = tf.math.sqrt(state_norm)
287
+ state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
288
+
289
+ # Create a list of steps
290
+ step_data = []
291
+ for i in range(tf.shape(states)[0]):
292
+ step_data.append({
293
+ 'step_id': episode['step_id'][i],
294
+ 'json_content': json_content,
295
+ 'state_chunk': past_states[i],
296
+ 'state_chunk_time_mask': past_states_time_mask[i],
297
+ 'action_chunk': future_states[i],
298
+ 'action_chunk_time_mask': future_states_time_mask[i],
299
+ 'state_vec_mask': masks[i],
300
+ 'past_frames_0': episode['past_frames_0'][i],
301
+ 'past_frames_0_time_mask': episode['past_frames_0_time_mask'][i],
302
+ 'past_frames_1': episode['past_frames_1'][i],
303
+ 'past_frames_1_time_mask': episode['past_frames_1_time_mask'][i],
304
+ 'past_frames_2': episode['past_frames_2'][i],
305
+ 'past_frames_2_time_mask': episode['past_frames_2_time_mask'][i],
306
+ 'past_frames_3': episode['past_frames_3'][i],
307
+ 'past_frames_3_time_mask': episode['past_frames_3_time_mask'][i],
308
+ 'state_std': state_std[i],
309
+ 'state_mean': state_mean[i],
310
+ 'state_norm': state_norm[i],
311
+ })
312
+
313
+ return step_data
314
+
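For intuition, here is a small self-contained sketch of the windowing used above for past_states (toy values; CHUNK stands in for ACTION_CHUNK_SIZE): the sequence is front-padded with its first element, and tf.map_fn then slices one fixed-size window per step.

import tensorflow as tf

CHUNK = 3                                                # stands in for ACTION_CHUNK_SIZE
states = tf.constant([[0.], [1.], [2.], [3.]])           # (T=4, D=1)

first = tf.repeat(states[:1], CHUNK - 1, axis=0)         # pad the front with state[0]
padded = tf.concat([first, states], axis=0)              # (T + CHUNK - 1, D)
indices = tf.range(CHUNK, tf.shape(states)[0] + CHUNK)
past = tf.map_fn(lambda i: padded[i - CHUNK:i], indices, dtype=tf.float32)
# past has shape (4, 3, 1):
#   past[0] == [[0.], [0.], [0.]]   (fully padded at the start)
#   past[3] == [[1.], [2.], [3.]]   (the last three real states)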
315
+
316
+ def flatten_episode_agilex(episode: dict) -> list:
317
+ """
318
+ Flatten the episode to a list of steps.
319
+ """
320
+ episode_dict = episode['episode_dict']
321
+ dataset_name = episode['dataset_name']
322
+
323
+ json_content, states, masks, acts = generate_json_state(
324
+ episode_dict, dataset_name
325
+ )
326
+
327
+ # Calculate the past_states for each step
328
+ # Each step has a window of previous states with size ACTION_CHUNK_SIZE
329
+ # Use the first state to pad the states
330
+ # past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
331
+ first_state = tf.expand_dims(states[0], axis=0)
332
+ first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE-1, axis=0)
333
+ padded_states = tf.concat([first_state, states], axis=0)
334
+ indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
335
+ past_states = tf.map_fn(
336
+ lambda i: padded_states[i - ACTION_CHUNK_SIZE:i],
337
+ indices,
338
+ dtype=tf.float32
339
+ )
340
+ states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
341
+ padded_states_time_mask = tf.pad(states_time_mask, [[ACTION_CHUNK_SIZE-1, 0]], "CONSTANT", constant_values=False)
342
+ past_states_time_mask = tf.map_fn(
343
+ lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
344
+ indices,
345
+ dtype=tf.bool
346
+ )
347
+
348
+ # NOTE: here the "future states" are actually the actions
349
+ # Calculate the future_states for each step
350
+ # Each step has a window of future states with size ACTION_CHUNK_SIZE
351
+ # Use the last action to pad the states
352
+ # future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
353
+ last_act = tf.expand_dims(acts[-1], axis=0)
354
+ last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
355
+ padded_states = tf.concat([acts, last_act], axis=0)
356
+ # indices = tf.range(1, tf.shape(states)[0] + 1)
357
+ indices = tf.range(0, tf.shape(acts)[0]) # NOTE: the action at time t corresponds to the state at time t+1, so the window starts at t itself
358
+ future_states = tf.map_fn(
359
+ lambda i: padded_states[i:i + ACTION_CHUNK_SIZE],
360
+ indices,
361
+ dtype=tf.float32
362
+ )
363
+ states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
364
+ padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
365
+ future_states_time_mask = tf.map_fn(
366
+ lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
367
+ indices,
368
+ dtype=tf.bool
369
+ )
370
+
371
+ # Calculate the std and mean for state
372
+ state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
373
+ state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
374
+ state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
375
+ state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
376
+
377
+ state_norm = tf.math.reduce_mean(
378
+ tf.math.square(acts), axis=0, keepdims=True)
379
+ state_norm = tf.math.sqrt(state_norm)
380
+ state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
381
+
382
+ # Create a list of steps
383
+ step_data = []
384
+ for i in range(tf.shape(states)[0]):
385
+ step_data.append({
386
+ 'step_id': episode['step_id'][i],
387
+ 'json_content': json_content,
388
+ 'state_chunk': past_states[i],
389
+ 'state_chunk_time_mask': past_states_time_mask[i],
390
+ 'action_chunk': future_states[i],
391
+ 'action_chunk_time_mask': future_states_time_mask[i],
392
+ 'state_vec_mask': masks[i],
393
+ 'past_frames_0': episode['past_frames_0'][i],
394
+ 'past_frames_0_time_mask': episode['past_frames_0_time_mask'][i],
395
+ 'past_frames_1': episode['past_frames_1'][i],
396
+ 'past_frames_1_time_mask': episode['past_frames_1_time_mask'][i],
397
+ 'past_frames_2': episode['past_frames_2'][i],
398
+ 'past_frames_2_time_mask': episode['past_frames_2_time_mask'][i],
399
+ 'past_frames_3': episode['past_frames_3'][i],
400
+ 'past_frames_3_time_mask': episode['past_frames_3_time_mask'][i],
401
+ 'state_std': state_std[i],
402
+ 'state_mean': state_mean[i],
403
+ 'state_norm': state_norm[i],
404
+ })
405
+
406
+ return step_data
data/filelock.py ADDED
@@ -0,0 +1,24 @@
1
+ import fcntl
2
+
3
+
4
+ class FileLock:
5
+ """
6
+ A file lock class.
7
+ """
8
+ def __init__(self, filename):
9
+ self.filename = filename
10
+ self.handle = None
11
+
12
+ def acquire_read_lock(self):
13
+ self.handle = open(self.filename + '.lock', 'r')
14
+ fcntl.flock(self.handle, fcntl.LOCK_SH | fcntl.LOCK_NB)
15
+
16
+ def acquire_write_lock(self):
17
+ self.handle = open(self.filename + '.lock', 'w')
18
+ fcntl.flock(self.handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
19
+
20
+ def release_lock(self):
21
+ if self.handle is not None:
22
+ fcntl.flock(self.handle, fcntl.LOCK_UN)
23
+ self.handle.close()
24
+ self.handle = None
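A hypothetical usage sketch (not part of the file). Note that LOCK_NB makes flock non-blocking, so acquire_read_lock/acquire_write_lock raise BlockingIOError if another process already holds the lock; callers may want to catch that and retry.

lock = FileLock("data/some_shared_file")   # hypothetical path; creates data/some_shared_file.lock
try:
    lock.acquire_write_lock()              # exclusive, non-blocking
    # ... write to data/some_shared_file ...
finally:
    lock.release_lock()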
data/hdf5_maniskill_dataset.py ADDED
@@ -0,0 +1,243 @@
1
+ import os
2
+ import h5py
3
+ import yaml
4
+ import numpy as np
5
+ # Assuming STATE_VEC_IDX_MAPPING is a dictionary mapping state variable names to indices
6
+ from configs.state_vec import STATE_VEC_IDX_MAPPING
7
+ import glob
8
+ from scipy.interpolate import interp1d
9
+ from PIL import Image
10
+
11
+
12
+ def interpolate_action_sequence(action_sequence, target_size):
13
+ """
14
+ Extend the action sequece to `target_size` by linear interpolation.
15
+
16
+ Args:
17
+ action_sequence (np.ndarray): original action sequence, shape (N, D).
18
+ target_size (int): target sequence length.
19
+
20
+ Returns:
21
+ extended_sequence (np.ndarray): extended action sequence, shape (target_size, D).
22
+ """
23
+ N, D = action_sequence.shape
24
+ indices_old = np.arange(N)
25
+ indices_new = np.linspace(0, N - 1, target_size)
26
+
27
+ interp_func = interp1d(indices_old, action_sequence,
28
+ kind='linear', axis=0, assume_sorted=True)
29
+ action_sequence_new = interp_func(indices_new)
30
+
31
+ return action_sequence_new
32
+
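A quick worked example of the interpolation (illustrative values only):

import numpy as np

seq = np.array([[0.0,  0.0],
                [1.0, 10.0]])                 # (N=2, D=2)
out = interpolate_action_sequence(seq, 5)     # (5, 2)
# out[:, 0] -> [0.0, 0.25, 0.5, 0.75, 1.0]
# out[:, 1] -> [0.0, 2.5,  5.0, 7.5, 10.0]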
33
+
34
+ class HDF5VLADataset:
35
+ """
36
+ This class is used to sample episodes from the embodiment dataset
37
+ stored in HDF5 files.
38
+ """
39
+ def __init__(self):
40
+ # The name of your dataset
41
+ self.DATASET_NAME = "agilex"
42
+
43
+ self.data_dir = "data/datasets/rdt-ft-data/demo_1k"
44
+ self.tasks = os.listdir(self.data_dir)
45
+ # Multiple tasks
46
+ self.tasks = ['PickCube-v1', 'StackCube-v1', 'PlugCharger-v1', 'PushCube-v1', 'PegInsertionSide-v1']
47
+ # Load configuration from YAML file
48
+ with open('configs/base.yaml', 'r') as file:
49
+ config = yaml.safe_load(file)
50
+ self.CHUNK_SIZE = config['common']['action_chunk_size']
51
+ self.IMG_HISTORY_SIZE = config['common']['img_history_size']
52
+ self.STATE_DIM = config['common']['state_dim']
53
+
54
+ self.num_episode_per_task = 1000
55
+ self.img = []
56
+ self.state = []
57
+ self.action = []
58
+
59
+ # open the hdf5 files in memory to speed up the data loading
60
+ for task in self.tasks:
61
+ file_path = glob.glob(os.path.join(self.data_dir, task, 'motionplanning', '*.h5'))[0]
62
+ with h5py.File(file_path, "r") as f:
63
+ trajs = f.keys() # traj_0, traj_1,
64
+ # sort by the traj number
65
+ trajs = sorted(trajs, key=lambda x: int(x.split('_')[-1]))
66
+ for traj in trajs:
67
+ # images = f[traj]['obs']['sensor_data']['base_camera']['rgb'][:]
68
+ states = f[traj]['obs']['agent']['qpos'][:]
69
+ actions = f[traj]['actions'][:]
70
+
71
+ self.state.append(states)
72
+ self.action.append(actions)
73
+ # self.img.append(images)
74
+
75
+ self.state_min = np.concatenate(self.state).min(axis=0)
76
+ self.state_max = np.concatenate(self.state).max(axis=0)
77
+ self.action_min = np.concatenate(self.action).min(axis=0)
78
+ self.action_max = np.concatenate(self.action).max(axis=0)
79
+ self.action_std = np.concatenate(self.action).std(axis=0)
80
+ self.action_mean = np.concatenate(self.action).mean(axis=0)
81
+
82
+ self.task2lang = {
83
+ "PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
84
+ "PickCube-v1": "Grasp a red cube and move it to a target goal position.",
85
+ "StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
86
+ "PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
87
+ "PushCube-v1": "Push and move a cube to a goal region in front of it."
88
+ }
89
+
90
+ def __len__(self):
91
+ # Each task contributes num_episode_per_task episodes
92
+ return len(self.tasks) * self.num_episode_per_task
93
+
94
+ def get_dataset_name(self):
95
+ return self.DATASET_NAME
96
+
97
+ def get_item(self, index=None):
98
+ """
99
+ Get a training sample at a random timestep.
100
+
101
+ Args:
102
+ index (int, optional): The index of the episode.
103
+ If not provided, a random episode will be selected.
104
107
+
108
+ Returns:
109
+ sample (dict): A dictionary containing the training sample.
110
+ """
111
+ while True:
112
+ if index is None:
113
+ index = np.random.randint(0, self.__len__())
114
+ valid, sample = self.parse_hdf5_file(index)
115
+ if valid:
116
+ return sample
117
+ else:
118
+ index = np.random.randint(0, self.__len__())
119
+
120
+ def parse_hdf5_file(self, index):
121
+ """
122
+ Parse an HDF5 file to generate a training sample at a random timestep.
123
+
124
+ Args:
125
+ index (int): The index of the episode.
126
+
127
+ Returns:
128
+ valid (bool): Whether the episode is valid.
129
+ dict: A dictionary containing the training sample.
130
+ """
131
+ num_steps = len(self.action[index])
132
+ step_index = np.random.randint(0, num_steps)
133
+ task_index = index // self.num_episode_per_task
134
+ language = self.task2lang[self.tasks[task_index]]
135
+ task_inner_index = index % self.num_episode_per_task
136
+ # Skip these episodes since they are invalid in the eef version of the dataset.
137
+ if self.tasks[task_index] == 'PegInsertionSide-v1' and task_inner_index > 400:
138
+ return False, None
139
+ proc_index = task_inner_index // 100
140
+ episode_index = task_inner_index % 100
141
+ # images0 = self.img[index]
142
+ # normalize to -1, 1
143
+ states = (self.state[index] - self.state_min) / (self.state_max - self.state_min) * 2 - 1
144
+ states = states[:, :-1] # drop the last state dim, which duplicates the second-to-last (mirrored gripper finger)
145
+ actions = (self.action[index] - self.action_min) / (self.action_max - self.action_min) * 2 - 1
146
+
147
+ # Get image history
148
+ start_img_idx = max(0, step_index - self.IMG_HISTORY_SIZE + 1)
149
+ end_img_idx = step_index + 1
150
+ img_history = []
151
+ for i in range(start_img_idx, end_img_idx):
152
+ img_path = os.path.join(self.data_dir, self.tasks[task_index], 'motionplanning', f'{proc_index}', f'{episode_index}', f"{i + 1}.png")
153
+ img = np.array(Image.open(img_path))
154
+ img_history.append(img)
155
+ img_history = np.array(img_history)
156
+ # img_history = images0[start_img_idx:end_img_idx]
157
+ img_valid_len = img_history.shape[0]
158
+
159
+ # Pad images if necessary
160
+ if img_valid_len < self.IMG_HISTORY_SIZE:
161
+ padding = np.tile(img_history[0:1], (self.IMG_HISTORY_SIZE - img_valid_len, 1, 1, 1))
162
+ img_history = np.concatenate([padding, img_history], axis=0)
163
+
164
+ img_history_mask = np.array(
165
+ [False] * (self.IMG_HISTORY_SIZE - img_valid_len) + [True] * img_valid_len
166
+ )
167
+
168
+ # Compute state statistics
169
+ state_std = np.std(states, axis=0)
170
+ state_mean = np.mean(states, axis=0)
171
+ state_norm = np.sqrt(np.mean(states ** 2, axis=0))
172
+
173
+ # Get state and action at the specified timestep
174
+ state = states[step_index: step_index + 1]
175
+ runtime_chunksize = self.CHUNK_SIZE // 4
176
+ action_sequence = actions[step_index: step_index + runtime_chunksize]
177
+ # If the action chunk is too short, pad it with the last action, then linearly interpolate it up to CHUNK_SIZE
178
+
179
+ # Pad action sequence if necessary
180
+ if action_sequence.shape[0] < runtime_chunksize:
181
+ padding = np.tile(action_sequence[-1:], (runtime_chunksize - action_sequence.shape[0], 1))
182
+ action_sequence = np.concatenate([action_sequence, padding], axis=0)
183
+
184
+ action_sequence = interpolate_action_sequence(action_sequence, self.CHUNK_SIZE)
185
+
186
+ # Fill state and action into unified vectors
187
+ def fill_in_state(values):
188
+ UNI_STATE_INDICES = [
189
+ STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(7)
190
+ ] + [
191
+ STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
192
+ ]
193
+ uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
194
+ uni_vec[..., UNI_STATE_INDICES] = values
195
+ return uni_vec
196
+
197
+ state_indicator = fill_in_state(np.ones_like(state_std))
198
+ state = fill_in_state(state)
199
+ state_std = fill_in_state(state_std)
200
+ state_mean = fill_in_state(state_mean)
201
+ state_norm = fill_in_state(state_norm)
202
+ action_sequence = fill_in_state(action_sequence)
203
+
204
+ # Assemble the meta information
205
+ meta = {
206
+ "dataset_name": self.DATASET_NAME,
207
+ "#steps": num_steps,
208
+ "step_id": step_index,
209
+ "instruction": language
210
+ }
211
+
212
+ # Return the resulting sample
213
+ return True, {
214
+ "meta": meta,
215
+ "state": state,
216
+ "state_std": state_std,
217
+ "state_mean": state_mean,
218
+ "state_norm": state_norm,
219
+ "actions": action_sequence,
220
+ "state_indicator": state_indicator,
221
+ "cam_high": img_history, # Assuming images0 are high-level camera images
222
+ "cam_high_mask": img_history_mask,
223
+ "cam_left_wrist": np.zeros((self.IMG_HISTORY_SIZE, 0, 0, 0)),
224
+ "cam_left_wrist_mask": np.zeros(self.IMG_HISTORY_SIZE, dtype=bool),
225
+ "cam_right_wrist": np.zeros((self.IMG_HISTORY_SIZE, 0, 0, 0)),
226
+ "cam_right_wrist_mask": np.zeros(self.IMG_HISTORY_SIZE, dtype=bool),
227
+ }
228
+
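The states and actions above are min-max normalized to [-1, 1], and the min/max statistics are dumped in the __main__ block below. A small sketch of the mapping and its inverse, which an inference script would need to recover raw actions; denormalize is a hypothetical helper, not part of this file.

import numpy as np

def denormalize(x_norm, x_min, x_max):
    # Inverse of x_norm = (x - x_min) / (x_max - x_min) * 2 - 1
    return (x_norm + 1) / 2 * (x_max - x_min) + x_min

x = np.array([0.5])
x_min, x_max = np.array([0.0]), np.array([2.0])
x_norm = (x - x_min) / (x_max - x_min) * 2 - 1      # -> [-0.5]
assert np.allclose(denormalize(x_norm, x_min, x_max), x)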
229
+
230
+ if __name__ == "__main__":
231
+ from PIL import Image
232
+
233
+ ds = HDF5VLADataset()
234
+
235
+ json_data = {
236
+ 'state_min': ds.state_min.tolist(),
237
+ 'state_max': ds.state_max.tolist(),
238
+ 'action_min': ds.action_min.tolist(),
239
+ 'action_max': ds.action_max.tolist(),
240
+ }
241
+ print(json_data)
242
+
243
+
data/hdf5_vla_dataset.py ADDED
@@ -0,0 +1,533 @@
1
+ import os
2
+ import fnmatch
3
+ import json
4
+
5
+ import h5py
6
+ import yaml
7
+ import cv2
8
+ import numpy as np
9
+
10
+ from configs.state_vec import STATE_VEC_IDX_MAPPING
11
+ TABLETOP_6D_INDICES_NAMES = [
12
+ 'left_eef_pos_x','left_eef_pos_y','left_eef_pos_z','left_eef_angle_0','left_eef_angle_1','left_eef_angle_2','left_eef_angle_3','left_eef_angle_4','left_eef_angle_5','left_gripper_open','right_eef_pos_x','right_eef_pos_y','right_eef_pos_z','right_eef_angle_0','right_eef_angle_1','right_eef_angle_2','right_eef_angle_3','right_eef_angle_4','right_eef_angle_5','right_gripper_open']
13
+ TABLETOP_6D_INDICES = [STATE_VEC_IDX_MAPPING[n] for n in TABLETOP_6D_INDICES_NAMES]
14
+
15
+ class TabletopHDF5VLADataset:
16
+ """
17
+ This class is used to sample episodes from the embodiment dataset
18
+ stored in HDF5.
19
+ """
20
+ def __init__(self, task_name) -> None:
21
+ # [Modify] The path to the HDF5 dataset directory
22
+ # Each HDF5 file contains one episode
23
+ dataset_name = task_name
24
+ HDF5_DIR = f"/data5/jellyho/tabletop/{dataset_name}/"
25
+ self.DATASET_NAME = dataset_name
26
+
27
+ self.file_paths = []
28
+ for root, _, files in os.walk(HDF5_DIR):
29
+ for filename in fnmatch.filter(files, '*.hdf5'):
30
+ file_path = os.path.join(root, filename)
31
+ self.file_paths.append(file_path)
32
+
33
+ # Load the config
34
+ with open('configs/base.yaml', 'r') as file:
35
+ config = yaml.safe_load(file)
36
+ self.CHUNK_SIZE = config['common']['action_chunk_size']
37
+ self.IMG_HISORY_SIZE = config['common']['img_history_size']
38
+ self.STATE_DIM = config['common']['state_dim']
39
+
40
+ # Get each episode's len
41
+ episode_lens = []
42
+ for file_path in self.file_paths:
43
+ valid, res = self.parse_hdf5_file_state_only(file_path)
44
+ _len = res['state'].shape[0] if valid else 0
45
+ episode_lens.append(_len)
46
+ self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
47
+
48
+ def __len__(self):
49
+ return len(self.file_paths)
50
+
51
+ def get_dataset_name(self):
52
+ return self.DATASET_NAME
53
+
54
+ def get_item(self, index: int=None, state_only=False):
55
+ """Get a training sample at a random timestep.
56
+
57
+ Args:
58
+ index (int, optional): the index of the episode.
59
+ If not provided, a random episode will be selected.
60
+ state_only (bool, optional): Whether to return only the state.
61
+ In this way, the sample will contain a complete trajectory rather
62
+ than a single timestep. Defaults to False.
63
+
64
+ Returns:
65
+ sample (dict): a dictionary containing the training sample.
66
+ """
67
+ while True:
68
+ if index is None:
69
+ file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
70
+ else:
71
+ file_path = self.file_paths[index]
72
+ valid, sample = self.parse_hdf5_file(file_path) \
73
+ if not state_only else self.parse_hdf5_file_state_only(file_path)
74
+ if valid:
75
+ return sample
76
+ else:
77
+ index = np.random.randint(0, len(self.file_paths))
78
+
79
+ def parse_hdf5_file(self, file_path):
80
+ """[Modify] Parse a hdf5 file to generate a training sample at
81
+ a random timestep.
82
+
83
+ Args:
84
+ file_path (str): the path to the hdf5 file
85
+
86
+ Returns:
87
+ valid (bool): whether the episode is valid, which is useful for filtering.
88
+ If False, this episode will be dropped.
89
+ dict: a dictionary containing the training sample,
90
+ {
91
+ "meta": {
92
+ "dataset_name": str, # the name of your dataset.
93
+ "#steps": int, # the number of steps in the episode,
94
+ # also the total timesteps.
95
+ "instruction": str # the language instruction for this episode.
96
+ },
97
+ "step_id": int, # the index of the sampled step,
98
+ # also the timestep t.
99
+ "state": ndarray, # state[t], (1, STATE_DIM).
100
+ "state_std": ndarray, # std(state[:]), (STATE_DIM,).
101
+ "state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
102
+ "state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
103
+ "actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
104
+ "state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
105
+ "cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
106
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
107
+ "cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
108
+ # For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
109
+ "cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
110
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
111
+ "cam_left_wrist_mask": ndarray,
112
+ "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
113
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
114
+ # If there is only one wrist camera, use the right wrist.
115
+ "cam_right_wrist_mask": ndarray
116
+ } or None if the episode is invalid.
117
+ """
118
+ with h5py.File(file_path, 'r') as f:
119
+ states = f['observations']['states']['ee_6d_pos'][:]
120
+ actions = f['actions']['ee_6d_pos'][:]
121
+ num_steps = states.shape[0]
122
+ # [Optional] We drop too-short episode
123
+ if num_steps < 20:
124
+ return False, None
125
+
126
+ # We randomly sample a timestep
127
+ step_id = np.random.randint(0, num_steps)
128
+
129
+ # You can also use precomputed language embeddings (recommended)
130
+ if self.DATASET_NAME == 'aloha_box_into_pot_easy':
131
+ instruction = f['observations']['states']['language_instruction'][0].decode('utf-8')
132
+ else:
133
+ instruction = f"lang_embed/{self.DATASET_NAME}.pt"
134
+
135
+ # Assemble the meta
136
+ meta = {
137
+ "dataset_name": self.DATASET_NAME,
138
+ "#steps": num_steps,
139
+ "step_id": step_id,
140
+ "instruction": instruction
141
+ }
142
+
143
+ # Rescale gripper to [0, 1]
144
+ states = states / np.array(
145
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
146
+ )
147
+ actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
148
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
149
+ )
150
+
151
+ # Parse the state and action
152
+ state = states[step_id:step_id+1]
153
+ state_std = np.std(states, axis=0)
154
+ state_mean = np.mean(states, axis=0)
155
+ state_norm = np.sqrt(np.mean(states**2, axis=0))
156
+
157
+ if actions.shape[0] < self.CHUNK_SIZE:
158
+ # Pad the actions using the last action
159
+ actions = np.concatenate([
160
+ actions,
161
+ np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
162
+ ], axis=0)
163
+
164
+ # Fill the state/action into the unified vector
165
+ def fill_in_state(values):
166
+ uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
167
+ uni_vec[..., TABLETOP_6D_INDICES] = values
168
+ return uni_vec
169
+ state = fill_in_state(state)
170
+ state_indicator = fill_in_state(np.ones_like(state_std))
171
+ state_std = fill_in_state(state_std)
172
+ state_mean = fill_in_state(state_mean)
173
+ state_norm = fill_in_state(state_norm)
174
+ # If action's format is different from state's,
175
+ # you may implement fill_in_action()
176
+ actions = fill_in_state(actions)
177
+
178
+ # Parse the images
179
+ def parse_img(key):
180
+ imgs = []
181
+ for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1):
182
+ img = f['observations']['images'][key][i]
183
+ # imgs.append(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR))
184
+ imgs.append(img)
185
+ # print(imgs)
186
+ imgs = np.stack(imgs)
187
+ if imgs.shape[0] < self.IMG_HISORY_SIZE:
188
+ # Pad the images using the first image
189
+ imgs = np.concatenate([
190
+ np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
191
+ imgs
192
+ ], axis=0)
193
+ return imgs
194
+ # `cam_high` is the external camera image
195
+ cam_high = parse_img('back')
196
+ # For step_id = first_idx - 1, the valid_len should be one
197
+ valid_len = min(step_id + 1, self.IMG_HISORY_SIZE)
198
+ cam_high_mask = np.array(
199
+ [False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len
200
+ )
201
+ cam_left_wrist = parse_img('wrist_left')
202
+ cam_left_wrist_mask = cam_high_mask.copy()
203
+ cam_right_wrist = parse_img('wrist_right')
204
+ cam_right_wrist_mask = cam_high_mask.copy()
205
+
206
+ # print(cam_left_wrist is not None, cam_right_wrist is not None, cam_high is not None)
207
+
208
+ # Return the resulting sample
209
+ # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
210
+ # E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
211
+ # if the left-wrist camera is unavailable on your robot
212
+ return True, {
213
+ "meta": meta,
214
+ "state": state,
215
+ "state_std": state_std,
216
+ "state_mean": state_mean,
217
+ "state_norm": state_norm,
218
+ "actions": actions,
219
+ "state_indicator": state_indicator,
220
+ "cam_high": cam_high,
221
+ "cam_high_mask": cam_high_mask,
222
+ "cam_left_wrist": cam_left_wrist,
223
+ "cam_left_wrist_mask": cam_left_wrist_mask,
224
+ "cam_right_wrist": cam_right_wrist,
225
+ "cam_right_wrist_mask": cam_right_wrist_mask
226
+ }
227
+
228
+ def parse_hdf5_file_state_only(self, file_path):
229
+ """[Modify] Parse a hdf5 file to generate a state trajectory.
230
+
231
+ Args:
232
+ file_path (str): the path to the hdf5 file
233
+
234
+ Returns:
235
+ valid (bool): whether the episode is valid, which is useful for filtering.
236
+ If False, this episode will be dropped.
237
+ dict: a dictionary containing the training sample,
238
+ {
239
+ "state": ndarray, # state[:], (T, STATE_DIM).
240
+ "action": ndarray, # action[:], (T, STATE_DIM).
241
+ } or None if the episode is invalid.
242
+ """
243
+ with h5py.File(file_path, 'r') as f:
244
+ states = f['observations']['states']['ee_6d_pos'][:]
245
+ actions = f['actions']['ee_6d_pos'][:]
246
+ num_steps = states.shape[0]
247
+
248
+ step_id = np.random.randint(0, num_steps)
249
+
250
+ # Rescale gripper to [0, 1]
251
+ states = states / np.array(
252
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
253
+ )
254
+ actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
255
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
256
+ )
257
+
258
+ # Fill the state/action into the unified vector
259
+ def fill_in_state(values):
260
+ uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
261
+ uni_vec[..., TABLETOP_6D_INDICES] = values
262
+ return uni_vec
263
+ state = fill_in_state(states)
264
+ action = fill_in_state(actions)
265
+
266
+ # Return the resulting sample
267
+ return True, {
268
+ "state": state,
269
+ "action": action
270
+ }
271
+
272
+ class AnubisHDF5VLADataset:
273
+ """
274
+ This class is used to sample episodes from the embodiment dataset
275
+ stored in HDF5.
276
+ """
277
+ def __init__(self, task_name) -> None:
278
+ # [Modify] The path to the HDF5 dataset directory
279
+ # Each HDF5 file contains one episode
280
+ dataset_name = task_name
281
+ HDF5_DIR = f"/data5/jellyho/anubis_hdf5/{dataset_name}/"
282
+ self.DATASET_NAME = dataset_name
283
+
284
+ self.file_paths = []
285
+ for root, _, files in os.walk(HDF5_DIR):
286
+ for filename in fnmatch.filter(files, '*.hdf5'):
287
+ file_path = os.path.join(root, filename)
288
+ self.file_paths.append(file_path)
289
+
290
+ # Load the config
291
+ with open('configs/base.yaml', 'r') as file:
292
+ config = yaml.safe_load(file)
293
+ self.CHUNK_SIZE = config['common']['action_chunk_size']
294
+ self.IMG_HISORY_SIZE = config['common']['img_history_size']
295
+ self.STATE_DIM = config['common']['state_dim']
296
+
297
+ # Get each episode's len
298
+ episode_lens = []
299
+ for file_path in self.file_paths:
300
+ valid, res = self.parse_hdf5_file_state_only(file_path)
301
+ _len = res['state'].shape[0] if valid else 0
302
+ episode_lens.append(_len)
303
+ self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
304
+
305
+ def __len__(self):
306
+ return len(self.file_paths)
307
+
308
+ def get_dataset_name(self):
309
+ return self.DATASET_NAME
310
+
311
+ def get_item(self, index: int=None, state_only=False):
312
+ """Get a training sample at a random timestep.
313
+
314
+ Args:
315
+ index (int, optional): the index of the episode.
316
+ If not provided, a random episode will be selected.
317
+ state_only (bool, optional): Whether to return only the state.
318
+ In this way, the sample will contain a complete trajectory rather
319
+ than a single timestep. Defaults to False.
320
+
321
+ Returns:
322
+ sample (dict): a dictionary containing the training sample.
323
+ """
324
+ while True:
325
+ if index is None:
326
+ file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
327
+ else:
328
+ file_path = self.file_paths[index]
329
+ valid, sample = self.parse_hdf5_file(file_path) \
330
+ if not state_only else self.parse_hdf5_file_state_only(file_path)
331
+ if valid:
332
+ return sample
333
+ else:
334
+ index = np.random.randint(0, len(self.file_paths))
335
+
336
+ def parse_hdf5_file(self, file_path):
337
+ """[Modify] Parse a hdf5 file to generate a training sample at
338
+ a random timestep.
339
+
340
+ Args:
341
+ file_path (str): the path to the hdf5 file
342
+
343
+ Returns:
344
+ valid (bool): whether the episode is valid, which is useful for filtering.
345
+ If False, this episode will be dropped.
346
+ dict: a dictionary containing the training sample,
347
+ {
348
+ "meta": {
349
+ "dataset_name": str, # the name of your dataset.
350
+ "#steps": int, # the number of steps in the episode,
351
+ # also the total timesteps.
352
+ "instruction": str # the language instruction for this episode.
353
+ },
354
+ "step_id": int, # the index of the sampled step,
355
+ # also the timestep t.
356
+ "state": ndarray, # state[t], (1, STATE_DIM).
357
+ "state_std": ndarray, # std(state[:]), (STATE_DIM,).
358
+ "state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
359
+ "state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
360
+ "actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
361
+ "state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
362
+ "cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
363
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
364
+ "cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
365
+ # For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
366
+ "cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
367
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
368
+ "cam_left_wrist_mask": ndarray,
369
+ "cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
370
+ # or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
371
+ # If there is only one wrist camera, use the right wrist.
372
+ "cam_right_wrist_mask": ndarray
373
+ } or None if the episode is invalid.
374
+ """
375
+ with h5py.File(file_path, 'r') as f:
376
+ states = f['observation']['eef_pose'][:]
377
+ actions = f['action']['eef_pose'][:]
378
+ num_steps = states.shape[0]
379
+ # [Optional] We drop too-short episode
380
+ if num_steps < 20:
381
+ return False, None
382
+
383
+ # We randomly sample a timestep
384
+ step_id = np.random.randint(0, num_steps)
385
+
386
+ # You can also use precomputed language embeddings (recommended)
387
+ if self.DATASET_NAME == 'aloha_box_into_pot_easy':
388
+ instruction = f['observations']['states']['language_instruction'][0].decode('utf-8')
389
+ else:
390
+ instruction = f"lang_embed/{self.DATASET_NAME}.pt"
391
+
392
+ # Assemble the meta
393
+ meta = {
394
+ "dataset_name": self.DATASET_NAME,
395
+ "#steps": num_steps,
396
+ "step_id": step_id,
397
+ "instruction": instruction
398
+ }
399
+
400
+ # Rescale gripper to [0, 1]
401
+ states = states / np.array(
402
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
403
+ )
404
+ actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
405
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
406
+ )
407
+
408
+ # Parse the state and action
409
+ state = states[step_id:step_id+1]
410
+ state_std = np.std(states, axis=0)
411
+ state_mean = np.mean(states, axis=0)
412
+ state_norm = np.sqrt(np.mean(states**2, axis=0))
413
+
414
+ if actions.shape[0] < self.CHUNK_SIZE:
415
+ # Pad the actions using the last action
416
+ actions = np.concatenate([
417
+ actions,
418
+ np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
419
+ ], axis=0)
420
+
421
+ # Fill the state/action into the unified vector
422
+ def fill_in_state(values):
423
+ uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
424
+ uni_vec[..., TABLETOP_6D_INDICES] = values
425
+ return uni_vec
426
+ state = fill_in_state(state)
427
+ state_indicator = fill_in_state(np.ones_like(state_std))
428
+ state_std = fill_in_state(state_std)
429
+ state_mean = fill_in_state(state_mean)
430
+ state_norm = fill_in_state(state_norm)
431
+ # If action's format is different from state's,
432
+ # you may implement fill_in_action()
433
+ actions = fill_in_state(actions)
434
+
435
+ # Parse the images
436
+ def parse_img(key):
437
+ imgs = []
438
+ for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1):
439
+ img = f['observation'][key][i]
440
+ # imgs.append(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR))
441
+ imgs.append(img)
442
+ # print(imgs)
443
+ imgs = np.stack(imgs)
444
+ if imgs.shape[0] < self.IMG_HISORY_SIZE:
445
+ # Pad the images using the first image
446
+ imgs = np.concatenate([
447
+ np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
448
+ imgs
449
+ ], axis=0)
450
+ return imgs
451
+ # `cam_high` is the external camera image
452
+ cam_high = parse_img('agentview_image')
453
+ # For step_id = first_idx - 1, the valid_len should be one
454
+ valid_len = min(step_id + 1, self.IMG_HISORY_SIZE)
455
+ cam_high_mask = np.array(
456
+ [False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len
457
+ )
458
+ cam_left_wrist = parse_img('wrist_left_image')
459
+ cam_left_wrist_mask = cam_high_mask.copy()
460
+ cam_right_wrist = parse_img('wrist_right_image')
461
+ cam_right_wrist_mask = cam_high_mask.copy()
462
+
463
+ # print(cam_left_wrist is not None, cam_right_wrist is not None, cam_high is not None)
464
+
465
+ # Return the resulting sample
466
+ # For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
467
+ # E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
468
+ # if the left-wrist camera is unavailable on your robot
469
+ return True, {
470
+ "meta": meta,
471
+ "state": state,
472
+ "state_std": state_std,
473
+ "state_mean": state_mean,
474
+ "state_norm": state_norm,
475
+ "actions": actions,
476
+ "state_indicator": state_indicator,
477
+ "cam_high": cam_high,
478
+ "cam_high_mask": cam_high_mask,
479
+ "cam_left_wrist": cam_left_wrist,
480
+ "cam_left_wrist_mask": cam_left_wrist_mask,
481
+ "cam_right_wrist": cam_right_wrist,
482
+ "cam_right_wrist_mask": cam_right_wrist_mask
483
+ }
484
+
485
+ def parse_hdf5_file_state_only(self, file_path):
486
+ """[Modify] Parse a hdf5 file to generate a state trajectory.
487
+
488
+ Args:
489
+ file_path (str): the path to the hdf5 file
490
+
491
+ Returns:
492
+ valid (bool): whether the episode is valid, which is useful for filtering.
493
+ If False, this episode will be dropped.
494
+ dict: a dictionary containing the training sample,
495
+ {
496
+ "state": ndarray, # state[:], (T, STATE_DIM).
497
+ "action": ndarray, # action[:], (T, STATE_DIM).
498
+ } or None if the episode is invalid.
499
+ """
500
+ with h5py.File(file_path, 'r') as f:
501
+ states = f['observation']['eef_pose'][:]
502
+ actions = f['action']['eef_pose'][:]
503
+ num_steps = states.shape[0]
504
+
505
+ step_id = np.random.randint(0, num_steps)
506
+
507
+ # Rescale gripper to [0, 1]
508
+ states = states / np.array(
509
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
510
+ )
511
+ actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
512
+ [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
513
+ )
514
+
515
+ # Fill the state/action into the unified vector
516
+ def fill_in_state(values):
517
+ uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
518
+ uni_vec[..., TABLETOP_6D_INDICES] = values
519
+ return uni_vec
520
+ state = fill_in_state(states)
521
+ action = fill_in_state(actions)
522
+
523
+ # Return the resulting sample
524
+ return True, {
525
+ "state": state,
526
+ "action": action
527
+ }
528
+
529
+ if __name__ == "__main__":
530
+ ds = TabletopHDF5VLADataset('aloha_box_into_pot_easy') # task_name is required; this task name, referenced above, is used here as an example
531
+ for i in range(len(ds)):
532
+ print(f"Processing episode {i}/{len(ds)}...")
533
+ ds.get_item(i)
data/preprocess.py ADDED
@@ -0,0 +1,323 @@
1
+ import json
2
+
3
+ import tensorflow as tf
4
+ import yaml
5
+
6
+ from data.preprocess_scripts import *
7
+ from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
8
+ from data.utils import capitalize_and_period
9
+
10
+ # The dataset without state
11
+ DATASET_NAMES_NO_STATE = [
12
+ 'nyu_door_opening_surprising_effectiveness',
13
+ "usc_cloth_sim_converted_externally_to_rlds",
14
+ 'cmu_franka_exploration_dataset_converted_externally_to_rlds',
15
+ 'imperialcollege_sawyer_wrist_cam'
16
+ ]
17
+
18
+ # Read the image keys of each dataset
19
+ with open('configs/dataset_img_keys.json', 'r') as file:
20
+ IMAGE_KEYS = json.load(file)
21
+ # Read the config
22
+ with open('configs/base.yaml', 'r') as file:
23
+ config = yaml.safe_load(file)
24
+
25
+
26
+ def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str,
27
+ base_concat=None, base_format=None) -> tf.Tensor:
28
+ """
29
+ Assemble the state/action vector from the arm and base.
30
+ """
31
+ state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
32
+ mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
33
+
34
+ # Assemble the arm state
35
+ arm_concat = tf.cast(arm_concat, tf.float32)
36
+ arm_format = arm_format.split(',')
37
+ # Use the scatter_nd to avoid the duplicate indices
38
+ state_vec = tf.tensor_scatter_nd_update(
39
+ state_vec,
40
+ [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
41
+ arm_concat
42
+ )
43
+ mask_vec = tf.tensor_scatter_nd_update(
44
+ mask_vec,
45
+ [[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
46
+ tf.ones(len(arm_format), dtype=tf.float32)
47
+ )
48
+
49
+ # Assemble the base state if exists
50
+ if base_concat is not None:
51
+ base_concat = tf.cast(base_concat, tf.float32)
52
+ base_format = base_format.split(',')
53
+ state_vec = tf.tensor_scatter_nd_update(
54
+ state_vec,
55
+ [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
56
+ base_concat
57
+ )
58
+ mask_vec = tf.tensor_scatter_nd_update(
59
+ mask_vec,
60
+ [[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
61
+ tf.ones(len(base_format), dtype=tf.float32)
62
+ )
63
+ return state_vec, mask_vec
64
+
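A toy sketch of the scatter used above, with a made-up 4-dimensional vector and a made-up name-to-index mapping standing in for STATE_VEC_IDX_MAPPING:

import tensorflow as tf

idx_map = {'x': 0, 'y': 1, 'gripper_open': 3}     # hypothetical mapping
names = 'x,y,gripper_open'.split(',')
vec = tf.tensor_scatter_nd_update(
    tf.zeros(4, dtype=tf.float32),
    [[idx_map[n]] for n in names],
    tf.constant([0.1, 0.2, 1.0])
)
# vec == [0.1, 0.2, 0.0, 1.0]; the mask is built the same way with ones as updates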
65
+
66
+ @tf.autograph.experimental.do_not_convert
67
+ def _generate_json_state_agilex(episode: dict, dataset_name: str):
68
+ """
69
+ Generate the json dict and state for a given episode.
70
+ """
71
+ # Load some constants from the config
72
+ IMG_HISTORY_SIZE = config['common']['img_history_size']
73
+ if IMG_HISTORY_SIZE < 1:
74
+ raise ValueError("Config `img_history_size` must be at least 1.")
75
+ ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
76
+ if ACTION_CHUNK_SIZE < 1:
77
+ raise ValueError("Config `action_chunk_size` must be at least 1.")
78
+
79
+ # Initialize the episode_metadata
80
+ episode_metadata = {
81
+ 'dataset_name': dataset_name,
82
+ '#steps': 0,
83
+ 'instruction': None
84
+ }
85
+
86
+ # Check whether this episode has an 'END'
87
+ base_act = None
88
+ last_base_act = None
89
+ episode_states = []
90
+ episode_acts = []
91
+ episode_masks = []
92
+ has_base = None
93
+ for step_id, step in enumerate(iter(episode['steps'])):
94
+ # Parse the action
95
+ action = step['action']
96
+ if has_base is None:
97
+ has_base = 'base_concat' in action
98
+ if has_base:
99
+ base_act = action['base_concat']
100
+
101
+ # Parse the state
102
+ state = step['observation']
103
+
104
+ arm_format = state['format'].numpy().decode('utf-8')
105
+ base_format = None
106
+ if has_base:
107
+ act_format = action['format'].numpy().decode('utf-8')
108
+ base_format_idx = act_format.find('base')
109
+ base_format = act_format[base_format_idx:]
110
+
111
+ arm_state = state['arm_concat']
112
+ base_state = None
113
+ if has_base:
114
+ if last_base_act is None:
115
+ base_state = base_act * 0
116
+ else:
117
+ base_state = last_base_act
118
+ last_base_act = base_act
119
+
120
+ # Assemble the state vector
121
+ state_vec, mask_vec = assemble_state_vec(
122
+ arm_state, arm_format, base_state, base_format)
123
+
124
+
125
+ act_vec, mask_vec = assemble_state_vec(
126
+ action['arm_concat'], arm_format, base_state, base_format
127
+ )
128
+
129
+ episode_states.append(state_vec)
130
+ episode_masks.append(mask_vec)
131
+ episode_acts.append(act_vec)
132
+
133
+ # Parse the task instruction
134
+ instr = step['observation']['natural_language_instruction']
135
+ instr = instr.numpy().decode('utf-8')
136
+ instr = capitalize_and_period(instr)
137
+
138
+ # Write to the episode_metadata
139
+ if episode_metadata['instruction'] is None:
140
+ episode_metadata['instruction'] = instr
141
+
142
+ episode_metadata['#steps'] = step_id
143
+
144
+ episode_states = tf.stack(episode_states)
145
+ episode_masks = tf.stack(episode_masks)
146
+ episode_acts = tf.stack(episode_acts)
147
+
148
+ return episode_metadata, episode_states, episode_masks, episode_acts
149
+
150
+
151
+ @tf.autograph.experimental.do_not_convert
152
+ def _generate_json_state(episode: dict, dataset_name: str):
153
+ """
154
+ Generate the json dict and state for a given episode.
155
+ """
156
+ # Load some constants from the config
157
+ IMG_HISTORY_SIZE = config['common']['img_history_size']
158
+ if IMG_HISTORY_SIZE < 1:
159
+ raise ValueError("Config `img_history_size` must be at least 1.")
160
+ ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
161
+ if ACTION_CHUNK_SIZE < 1:
162
+ raise ValueError("Config `action_chunk_size` must be at least 1.")
163
+
164
+ # Initialize the episode_metadata
165
+ episode_metadata = {
166
+ 'dataset_name': dataset_name,
167
+ '#steps': 0,
168
+ 'instruction': None
169
+ }
170
+
171
+ # Check whether this episode has an 'END'
172
+ base_act = None
173
+ last_base_act = None
174
+ episode_states = []
175
+ episode_masks = []
176
+ has_base = None
177
+ for step_id, step in enumerate(iter(episode['steps'])):
178
+ # Parse the action
179
+ action = step['action']
180
+ if has_base is None:
181
+ has_base = 'base_concat' in action
182
+ if has_base:
183
+ base_act = action['base_concat']
184
+
185
+ # Parse the state
186
+ state = step['observation']
187
+
188
+ arm_format = state['format'].numpy().decode('utf-8')
189
+ base_format = None
190
+ if has_base:
191
+ act_format = action['format'].numpy().decode('utf-8')
192
+ base_format_idx = act_format.find('base')
193
+ base_format = act_format[base_format_idx:]
194
+
195
+ arm_state = state['arm_concat']
196
+ base_state = None
197
+ if has_base:
198
+ if last_base_act is None:
199
+ base_state = base_act * 0
200
+ else:
201
+ base_state = last_base_act
202
+ last_base_act = base_act
203
+
204
+ # Assemble the state vector
205
+ state_vec, mask_vec = assemble_state_vec(
206
+ arm_state, arm_format, base_state, base_format)
207
+
208
+ episode_states.append(state_vec)
209
+ episode_masks.append(mask_vec)
210
+
211
+ # Parse the task instruction
212
+ instr = step['observation']['natural_language_instruction']
213
+ instr = instr.numpy().decode('utf-8')
214
+ instr = capitalize_and_period(instr)
215
+
216
+ # Write to the episode_metadata
217
+ if episode_metadata['instruction'] is None:
218
+ episode_metadata['instruction'] = instr
219
+
220
+ episode_metadata['#steps'] = step_id
221
+ episode_states = tf.stack(episode_states)
222
+ episode_masks = tf.stack(episode_masks)
223
+
224
+ return episode_metadata, episode_states, episode_masks
225
+
226
+
227
+ @tf.autograph.experimental.do_not_convert
228
+ def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
229
+ """
230
+ Generate the json dict and state for an episode in the dataset without state.
231
+ If there is no state, we use the last action as the current state.
232
+ """
233
+ # Load some constants from the config
234
+ IMG_HISTORY_SIZE = config['common']['img_history_size']
235
+ if IMG_HISTORY_SIZE < 1:
236
+ raise ValueError("Config `img_history_size` must be at least 1.")
237
+ ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
238
+ if ACTION_CHUNK_SIZE < 1:
239
+ raise ValueError("Config `action_chunk_size` must be at least 1.")
240
+
241
+ # Initialize the episode_metadata
242
+ episode_metadata = {
243
+ 'dataset_name': dataset_name,
244
+ '#steps': 0,
245
+ 'instruction': None
246
+ }
247
+
248
+ last_base_act = None
249
+ last_arm_act = None
250
+ episode_states = []
251
+ episode_masks = []
252
+ has_base = None
253
+ for step_id, step in enumerate(iter(episode['steps'])):
254
+ # Parse the action
255
+ action = step['action']
256
+ if has_base is None:
257
+ has_base = 'base_concat' in action
258
+ if has_base:
259
+ base_act = action['base_concat']
260
+ if last_base_act is None:
261
+ last_base_act = base_act * 0 # Initialize
262
+
263
+ # Parse the arm action
264
+ arm_act = action['arm_concat']
265
+ if last_arm_act is None:
266
+ last_arm_act = arm_act * 0 # Initialize
267
+
268
+ # Parse the act format
269
+ # Action format as the state format
270
+ act_format = action['format'].numpy().decode('utf-8')
271
+
272
+ # Assemble the state vector
273
+ if has_base:
274
+ last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
275
+ else:
276
+ last_act_concat = last_arm_act
277
+ state_vec, mask_vec = assemble_state_vec(
278
+ last_act_concat, act_format)
279
+
280
+ episode_states.append(state_vec)
281
+ episode_masks.append(mask_vec)
282
+
283
+ # Parse the task instruction
284
+ instr = step['observation']['natural_language_instruction']
285
+ instr = instr.numpy().decode('utf-8')
286
+ instr = capitalize_and_period(instr)
287
+
288
+ # Write to the episode_metadata
289
+ if episode_metadata['instruction'] is None:
290
+ episode_metadata['instruction'] = instr
291
+
292
+ # Update the last_arm_act and last_base_act
293
+ last_arm_act = arm_act
294
+ if has_base:
295
+ last_base_act = base_act
296
+
297
+ episode_metadata['#steps'] = step_id
298
+ episode_states = tf.stack(episode_states)
299
+ episode_masks = tf.stack(episode_masks)
300
+
301
+ return episode_metadata, episode_states, episode_masks
302
+
303
+
304
+ @tf.autograph.experimental.do_not_convert
305
+ def generate_json_state(episode: dict, dataset_name: str):
306
+ """
307
+ Generate the json dict and state for an episode.
308
+ """
309
+ if isinstance(dataset_name, tf.Tensor):
310
+ dataset_name = dataset_name.numpy().decode('utf-8')
311
+
312
+ # Process each step in the episode
313
+ episode['steps'] = episode['steps'].map(
314
+ globals()[dataset_name].process_step,
315
+ )
316
+
317
+ if dataset_name == "agilex":
318
+ return _generate_json_state_agilex(episode, dataset_name)
319
+
320
+ if dataset_name in DATASET_NAMES_NO_STATE:
321
+ return _generate_json_state_nostate_ds(episode, dataset_name)
322
+
323
+ return _generate_json_state(episode, dataset_name)
data/preprocess_scripts/__init__.py ADDED
@@ -0,0 +1,73 @@
1
+ from . import fractal20220817_data
2
+ from . import bridge
3
+ from . import jaco_play
4
+ from . import nyu_door_opening_surprising_effectiveness
5
+ from . import taco_play
6
+ from . import berkeley_cable_routing
7
+ from . import roboturk
8
+ from . import viola
9
+ from . import berkeley_autolab_ur5
10
+ from . import toto
11
+ from . import columbia_cairlab_pusht_real
12
+ from . import stanford_kuka_multimodal_dataset_converted_externally_to_rlds
13
+ from . import nyu_rot_dataset_converted_externally_to_rlds
14
+ from . import austin_buds_dataset_converted_externally_to_rlds
15
+ from . import nyu_franka_play_dataset_converted_externally_to_rlds
16
+ from . import cmu_franka_exploration_dataset_converted_externally_to_rlds
17
+ from . import kuka
18
+ from . import utokyo_xarm_bimanual_converted_externally_to_rlds
19
+ from . import maniskill_dataset_converted_externally_to_rlds
20
+ from . import stanford_hydra_dataset_converted_externally_to_rlds
21
+ from . import ucsd_kitchen_dataset_converted_externally_to_rlds
22
+ from . import ucsd_pick_and_place_dataset_converted_externally_to_rlds
23
+ from . import austin_sailor_dataset_converted_externally_to_rlds
24
+ from . import austin_sirius_dataset_converted_externally_to_rlds
25
+ from . import bc_z
26
+ from . import usc_cloth_sim_converted_externally_to_rlds
27
+ from . import utokyo_pr2_opening_fridge_converted_externally_to_rlds
28
+ from . import utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds
29
+ from . import utokyo_xarm_pick_and_place_converted_externally_to_rlds
30
+ from . import berkeley_mvp_converted_externally_to_rlds
31
+ from . import berkeley_rpt_converted_externally_to_rlds
32
+ from . import kaist_nonprehensile_converted_externally_to_rlds
33
+ from . import stanford_mask_vit_converted_externally_to_rlds
34
+ from . import tokyo_u_lsmo_converted_externally_to_rlds
35
+ from . import dlr_sara_pour_converted_externally_to_rlds
36
+ from . import dlr_sara_grid_clamp_converted_externally_to_rlds
37
+ from . import dlr_edan_shared_control_converted_externally_to_rlds
38
+ from . import asu_table_top_converted_externally_to_rlds
39
+ from . import stanford_robocook_converted_externally_to_rlds
40
+ from . import roboturk_real_laundrylayout
41
+ from . import roboturk_real_towercreation
42
+ from . import roboturk_real_objectsearch
43
+ from . import robomimic_lift_ph
44
+ from . import robomimic_can_ph
45
+ from . import robomimic_square_ph
46
+ from . import robomimic_transport_ph
47
+ from . import robomimic_tool_hang_ph
48
+ from . import eth_agent_affordances
49
+ from . import imperialcollege_sawyer_wrist_cam
50
+ from . import iamlab_cmu_pickup_insert_converted_externally_to_rlds
51
+ from . import uiuc_d3field
52
+ from . import utaustin_mutex
53
+ from . import berkeley_fanuc_manipulation
54
+ from . import cmu_play_fusion
55
+ from . import cmu_stretch
56
+ from . import berkeley_gnm_recon
57
+ from . import berkeley_gnm_cory_hall
58
+ from . import berkeley_gnm_sac_son
59
+ from . import language_table
60
+ from . import furniture_bench_dataset_converted_externally_to_rlds
61
+ from . import robo_net
62
+ from . import bridgev2
63
+ from . import aloha_mobile
64
+ from . import aloha_static
65
+ from . import droid
66
+ from . import fmb
67
+ from . import dobbe
68
+ from . import qut_dexterous_manpulation
69
+ from . import roboset
70
+ from . import agilex
71
+ from . import rh20t
72
+ from . import calvin
73
+ from . import aloha_dish_drainer, aloha_handover_box, aloha_shoes_table, aloha_lift_box, aloha_box_into_pot
data/preprocess_scripts/aloha_shoes_table.py ADDED
@@ -0,0 +1,55 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+ action_dict = step['action']
14
+ # Concatenate the action
15
+ step['action'] = {}
16
+ action = step['action']
17
+ action['arm_concat'] = action_dict['ee_6d_pos']
18
+
19
+ # Write the action format
20
+ action['format'] = tf.constant(
21
+ "left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open"
22
+ )
23
+
24
+ # Convert raw state to our state
25
+ # Robot state
26
+ state_dict = step['observation']['state']
27
+ state = {}
28
+ state['arm_concat'] = state_dict
29
+
30
+ # Write the state format
31
+ state['format'] = tf.constant(
32
+ "left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open"
33
+ )
34
+ # Clean the task instruction
35
+ # Define the replacements (old, new) as a dictionary
36
+ replacements = {
37
+ '_': ' ',
38
+ '1f': ' ',
39
+ '4f': ' ',
40
+ '-': ' ',
41
+ '50': ' ',
42
+ '55': ' ',
43
+ '56': ' ',
44
+
45
+ }
46
+ instr = step['language_instruction']
47
+ # instr = clean_task_instruction(instr, replacements)
48
+ step['observation'] = state
49
+ step['observation']['natural_language_instruction'] = instr
50
+
51
+ return step
52
+
53
+
54
+ if __name__ == "__main__":
55
+ pass
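The `format` string used throughout these preprocess scripts is a per-element schema for the concatenated vector: one comma-separated name per scalar in `arm_concat`. Below is a minimal sketch (illustration only, not repository code; the zero vector is a placeholder) of how the two line up for the 20-dim bimanual state above.

import numpy as np

fmt = ("left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,"
       "left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,"
       "left_eef_angle_4,left_eef_angle_5,left_gripper_open,"
       "right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,"
       "right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,"
       "right_eef_angle_4,right_eef_angle_5,right_gripper_open")
names = fmt.split(',')                               # 20 names: 3 pos + 6 ortho6d + 1 gripper per arm
arm_concat = np.zeros(len(names), dtype=np.float32)  # placeholder bimanual state vector
named = dict(zip(names, arm_concat))                 # name -> scalar lookup
assert len(names) == arm_concat.shape[0]             # schema and vector must stay in sync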
data/preprocess_scripts/austin_buds_dataset_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, rotation_matrix_to_ortho6d
4
+
5
+ def process_step(step: dict) -> dict:
6
+ """
7
+ Unify the action format and clean the task instruction.
8
+
9
+ DO NOT use python list, use tf.TensorArray instead.
10
+ """
11
+ # Convert raw action to our action
12
+
13
+ origin_action = step['action']
14
+ step['action']={}
15
+ action=step['action']
16
+ action['terminate'] = step['is_terminal']
17
+
18
+ eef_delta_pos = origin_action[:3]
19
+ eef_ang=origin_action[3:6]
20
+ eef_ang = euler_to_quaternion(eef_ang)
21
+ # gripper_open: -1-open, 1-closed
22
+ grip_open=tf.where(tf.equal(origin_action[6:],tf.constant(-1.0)),tf.constant(1.0),tf.constant(0.0))
23
+
24
+ # No base found
25
+
26
+ # Concatenate the action
27
+ action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
28
+
29
+ # Write the action format
30
+ action['format'] = tf.constant(
31
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
32
+
33
+ # Convert raw state to our state
34
+ state = step['observation']
35
+ # Concatenate the state
36
+ eef_mat = tf.transpose(tf.reshape(state['state'][8:], (4, 4)))
37
+ eef_pos = eef_mat[:3, 3]
38
+ rotation_matrix = eef_mat[:3, :3]
39
+ eef_ang = rotation_matrix_to_ortho6d(rotation_matrix)
40
+ joint_pos = state['state'][:7]
41
+ grip_open = state['state'][7:8] * 12.5 # rescale to [0, 1]
42
+ state['arm_concat'] = tf.concat([joint_pos,grip_open,eef_pos,eef_ang],axis=0)
43
+
44
+ # Write the state format
45
+ state['format'] = tf.constant(
46
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
47
+
48
+ # Clean the task instruction
49
+ # Define the replacements (old, new) as a dictionary
50
+ replacements = {
51
+ '_': ' ',
52
+ '1f': ' ',
53
+ '4f': ' ',
54
+ '-': ' ',
55
+ '50': ' ',
56
+ '55': ' ',
57
+ '56': ' ',
58
+
59
+ }
60
+ instr = step['language_instruction']
61
+ instr = clean_task_instruction(instr, replacements)
62
+ step['observation']['natural_language_instruction'] = instr
63
+
64
+ return step
65
+
66
+
67
+ if __name__ == "__main__":
68
+ import tensorflow_datasets as tfds
69
+ from data.utils import dataset_to_path
70
+
71
+ DATASET_DIR = 'data/datasets/openx_embod'
72
+ DATASET_NAME = 'austin_buds_dataset_converted_externally_to_rlds'
73
+ # Load the dataset
74
+ dataset = tfds.builder_from_directory(
75
+ builder_dir=dataset_to_path(
76
+ DATASET_NAME, DATASET_DIR))
77
+ dataset = dataset.as_dataset(split='all')
78
+
79
+ # Inspect the dataset
80
+ for episode in dataset:
81
+ for step in episode['steps']:
82
+ print(step)
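The `tf.transpose(tf.reshape(state['state'][8:], (4, 4)))` above implies that the trailing 16 state values encode a homogeneous end-effector transform stored in column-major order. A small self-contained check of that unpacking (illustration only; the transform below is made up, not dataset data):

import numpy as np

T = np.eye(4, dtype=np.float32)
T[:3, 3] = [0.4, 0.0, 0.2]                  # hypothetical eef position
flat = T.flatten(order='F')                 # column-major storage assumed by the transpose

eef_mat = flat.reshape(4, 4).T              # same effect as tf.transpose(tf.reshape(...))
assert np.allclose(eef_mat[:3, 3], [0.4, 0.0, 0.2])   # translation recovered
assert np.allclose(eef_mat[:3, :3], np.eye(3))        # rotation block recovered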
data/preprocess_scripts/berkeley_autolab_ur5.py ADDED
@@ -0,0 +1,95 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, \
4
+ quaternion_to_rotation_matrix, rotation_matrix_to_ortho6d
5
+
6
+
7
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
8
+ """
9
+ Convert terminate action to a boolean, where True means terminate.
10
+ """
11
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
12
+
13
+
14
+ def process_step(step: dict) -> dict:
15
+ """
16
+ Unify the action format and clean the task instruction.
17
+
18
+ DO NOT use python list, use tf.TensorArray instead.
19
+ """
20
+ # Convert raw action to our action
21
+ action = step['action']
22
+ action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
23
+ eef_delta_pos = action['world_vector']
24
+ eef_ang = action['rotation_delta']
25
+ eef_ang = euler_to_quaternion(eef_ang)
26
+
27
+ # Ignore action['gripper_open']: 1 if close gripper, -1 if open gripper, 0 if no change.
28
+
29
+ # No base found
30
+
31
+ # Concatenate the action
32
+ arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
33
+ action['arm_concat'] = arm_action
34
+
35
+ # Write the action format
36
+ action['format'] = tf.constant(
37
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
38
+
39
+ # Convert raw state to our state
40
+ state = step['observation']
41
+ # state['robot_state']:[joint0, joint1, joint2, joint3, joint4, joint5, x,y,z, qx,qy,qz,qw, gripper_is_closed, action_blocked]
42
+ robot_state = state['robot_state']
43
+ joint_pos=robot_state[:6]
44
+ eef_pos = robot_state[6:9]
45
+ eef_quat = robot_state[9:13]
46
+ eef_ang = quaternion_to_rotation_matrix(eef_quat)
47
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
48
+ # gripper_is_closed is binary: 0 = fully open; 1 = fully closed
49
+ grip_closed = robot_state[13:14]
50
+ grip_open= 1-grip_closed
51
+ # action_blocked is binary: 0 = not blocked; 1 = blocked
52
+ # action_blocked = robot_state[14:15]
53
+
54
+ # Concatenate the state
55
+ state['arm_concat'] = tf.concat([joint_pos, grip_open,eef_pos,eef_ang], axis=0)
56
+
57
+ # Write the state format
58
+ state['format'] = tf.constant(
59
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
60
+
61
+ # Clean the task instruction
62
+ # Define the replacements (old, new) as a dictionary
63
+ replacements = {
64
+ '_': ' ',
65
+ '1f': ' ',
66
+ '4f': ' ',
67
+ '-': ' ',
68
+ '50': ' ',
69
+ '55': ' ',
70
+ '56': ' ',
71
+
72
+ }
73
+ instr = step['observation']['natural_language_instruction']
74
+ instr = clean_task_instruction(instr, replacements)
75
+ step['observation']['natural_language_instruction'] = instr
76
+
77
+ return step
78
+
79
+
80
+ if __name__ == "__main__":
81
+ import tensorflow_datasets as tfds
82
+ from data.utils import dataset_to_path
83
+
84
+ DATASET_DIR = 'data/datasets/openx_embod'
85
+ DATASET_NAME = 'berkeley_autolab_ur5'
86
+ # Load the dataset
87
+ dataset = tfds.builder_from_directory(
88
+ builder_dir=dataset_to_path(
89
+ DATASET_NAME, DATASET_DIR))
90
+ dataset = dataset.as_dataset(split='all')
91
+
92
+ # Inspect the dataset
93
+ for episode in dataset:
94
+ for step in episode['steps']:
95
+ print(step)
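The state conversion above chains `quaternion_to_rotation_matrix` and `rotation_matrix_to_ortho6d` from `data/utils.py`, which is not shown in this part of the diff. The NumPy sketch below is a rough stand-in for what those helpers are expected to do; the (x, y, z, w) quaternion order and the first-two-columns 6D encoding are assumptions, not a copy of the repository implementation.

import numpy as np

def quaternion_to_rotation_matrix(q):
    x, y, z, w = q / np.linalg.norm(q)
    return np.array([
        [1 - 2*(y*y + z*z), 2*(x*y - z*w),     2*(x*z + y*w)],
        [2*(x*y + z*w),     1 - 2*(x*x + z*z), 2*(y*z - x*w)],
        [2*(x*z - y*w),     2*(y*z + x*w),     1 - 2*(x*x + y*y)],
    ])

def rotation_matrix_to_ortho6d(R):
    # Continuous 6D rotation representation (Zhou et al., 2019):
    # keep the first two columns; the third is recoverable by a cross product.
    return R[:, :2].T.flatten()

R = quaternion_to_rotation_matrix(np.array([0.0, 0.0, 0.0, 1.0]))  # identity rotation
print(rotation_matrix_to_ortho6d(R))  # -> [1, 0, 0, 0, 1, 0]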
data/preprocess_scripts/berkeley_cable_routing.py ADDED
@@ -0,0 +1,73 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_rotation_matrix, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
7
+ """
8
+ Convert terminate action to a boolean, where True means terminate.
9
+ """
10
+ return tf.equal(terminate_act, tf.constant(1.0, dtype=tf.float32))
11
+
12
+
13
+ def process_step(step: dict) -> dict:
14
+ """
15
+ Unify the action format and clean the task instruction.
16
+
17
+ DO NOT use python list, use tf.TensorArray instead.
18
+ """
19
+ # Convert raw action to our action
20
+ action = step['action']
21
+ action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
22
+
23
+ eef_delta_pos = action['world_vector']
24
+ eef_ang=action['rotation_delta']
25
+
26
+ # No gripper_open found
27
+ # No base found
28
+
29
+ # Concatenate the action
30
+ arm_action=tf.concat([eef_delta_pos,eef_ang],axis=0)
31
+ action['arm_concat']=arm_action
32
+ #base_action = tf.concat([base_pos, base_ang], axis=0)
33
+ #action['base_concat'] = base_action
34
+
35
+ # Write the action format
36
+ action['format']=tf.constant("eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw")
37
+
38
+ # Convert raw state to our state
39
+ state = step['observation']
40
+ eef_pos = state['robot_state'][:3]
41
+ eef_ang = quaternion_to_rotation_matrix(state['robot_state'][3:])
42
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
43
+
44
+ # Concatenate the state
45
+ state['arm_concat']=tf.concat([eef_pos,eef_ang],axis=0)
46
+
47
+ # Write the state format
48
+ state['format'] = tf.constant(
49
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
50
+
51
+ # Define the task instruction
52
+ step['observation']['natural_language_instruction'] = tf.constant(
53
+ "Route cable through the tight-fitting clip mounted on the table.")
54
+
55
+ return step
56
+
57
+
58
+ if __name__ == "__main__":
59
+ import tensorflow_datasets as tfds
60
+ from data.utils import dataset_to_path
61
+
62
+ DATASET_DIR = 'data/datasets/openx_embod/'
63
+ DATASET_NAME = 'berkeley_cable_routing'
64
+ # Load the dataset
65
+ dataset = tfds.builder_from_directory(
66
+ builder_dir=dataset_to_path(
67
+ DATASET_NAME, DATASET_DIR))
68
+ dataset = dataset.as_dataset(split='all')
69
+
70
+ # Inspect the dataset
71
+ for episode in dataset:
72
+ for step in episode['steps']:
73
+ print(step)
data/preprocess_scripts/berkeley_gnm_sac_son.py ADDED
@@ -0,0 +1,78 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
4
+ rotation_matrix_to_ortho6d
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+
14
+ origin_action = step['action']
15
+ step['action']={}
16
+ action=step['action']
17
+ action['terminate'] = step['is_terminal']
18
+
19
+ eef_pos=tf.cast(origin_action, dtype=tf.float32)
20
+ eef_ang=tf.cast(step['action_angle'][2:3], dtype=tf.float32)
21
+ eef_ang = euler_to_quaternion(tf.stack([0,0,eef_ang[0]], axis=0))
22
+ # No base found
23
+
24
+ # Concatenate the action
25
+ action['arm_concat'] = tf.concat([eef_pos,eef_ang],axis=0)
26
+
27
+ # Write the action format
28
+ action['format'] = tf.constant(
29
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
30
+
31
+ # Convert raw state to our state
32
+ state = step['observation']
33
+ # Concatenate the state
34
+ eef_pos=tf.cast(state['position'],dtype=tf.float32)
35
+ eef_ang=tf.cast(state['yaw'],dtype=tf.float32)
36
+ eef_ang = euler_to_rotation_matrix(tf.stack([0,0,eef_ang[0]],axis=0))
37
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
38
+ state['arm_concat'] = tf.concat([eef_pos/100,eef_ang],axis=0)
39
+ # Write the state format
40
+ state['format'] = tf.constant(
41
+ "eef_pos_x,eef_pos_y,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
42
+
43
+ # Clean the task instruction
44
+ # Define the replacements (old, new) as a dictionary
45
+ replacements = {
46
+ '_': ' ',
47
+ '1f': ' ',
48
+ '4f': ' ',
49
+ '-': ' ',
50
+ '50': ' ',
51
+ '55': ' ',
52
+ '56': ' ',
53
+
54
+ }
55
+ instr = step['language_instruction']
56
+ instr = clean_task_instruction(instr, replacements)
57
+ step['observation']['natural_language_instruction'] = instr
58
+
59
+ return step
60
+
61
+
62
+ if __name__ == "__main__":
63
+ import tensorflow_datasets as tfds
64
+ from data.utils import dataset_to_path
65
+
66
+ DATASET_DIR = 'data/datasets/openx_embod'
67
+ DATASET_NAME = 'berkeley_gnm_sac_son'
68
+ # Load the dataset
69
+ dataset = tfds.builder_from_directory(
70
+ builder_dir=dataset_to_path(
71
+ DATASET_NAME, DATASET_DIR))
72
+ dataset = dataset.as_dataset(split='all')
73
+
74
+ # Inspect the dataset
75
+ for episode in dataset:
76
+ for step in episode['steps']:
77
+ print(step['action'][6:7])
78
+
data/preprocess_scripts/berkeley_rpt_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,84 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+
5
+
6
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
7
+ """
8
+ Convert terminate action to a boolean, where True means terminate.
9
+ """
10
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
11
+
12
+
13
+ def process_step(step: dict) -> dict:
14
+ """
15
+ Unify the action format and clean the task instruction.
16
+
17
+ DO NOT use python list, use tf.TensorArray instead.
18
+ """
19
+ # Convert raw action to our action
20
+ action = step['action']
21
+ # Robot action, consists of [7 delta joint pos,1x gripper binary state].
22
+ delta_joint_pos = action[:7]
23
+ grip_open = tf.expand_dims(1 - action[7], axis=0)
24
+
25
+ # Concatenate the action
26
+ # action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
27
+ step['action'] = {}
28
+ action = step['action']
29
+ action['arm_concat'] = tf.concat([delta_joint_pos, grip_open], axis=0)
30
+ action['terminate'] = step['is_terminal']
31
+
32
+ # Write the action format
33
+ action['format'] = tf.constant(
34
+ "arm_joint_0_delta_pos,arm_joint_1_delta_pos,arm_joint_2_delta_pos,arm_joint_3_delta_pos,arm_joint_4_delta_pos,arm_joint_5_delta_pos,arm_joint_6_delta_pos,gripper_open")
35
+
36
+ # Convert raw state to our state
37
+ state = step['observation']
38
+ # xArm joint positions (7 DoF).
39
+ arm_joint_pos = state['joint_pos']
40
+ # Binary gripper state (1 - closed, 0 - open)
41
+ grip_open = tf.expand_dims(1 - tf.cast(state['gripper'],dtype=tf.float32), axis=0)
42
+
43
+ # Concatenate the state
44
+ state['arm_concat'] = tf.concat([arm_joint_pos, grip_open], axis=0)
45
+
46
+ # Write the state format
47
+ state['format'] = tf.constant(
48
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open")
49
+
50
+ # Clean the task instruction
51
+ # Define the replacements (old, new) as a dictionary
52
+ replacements = {
53
+ '_': ' ',
54
+ '1f': ' ',
55
+ '4f': ' ',
56
+ '-': ' ',
57
+ '50': ' ',
58
+ '55': ' ',
59
+ '56': ' ',
60
+
61
+ }
62
+ instr = step['language_instruction']
63
+ instr = clean_task_instruction(instr, replacements)
64
+ step['observation']['natural_language_instruction'] = instr
65
+
66
+ return step
67
+
68
+
69
+ if __name__ == "__main__":
70
+ import tensorflow_datasets as tfds
71
+ from data.utils import dataset_to_path
72
+
73
+ DATASET_DIR = 'data/datasets/openx_embod'
74
+ DATASET_NAME = 'berkeley_rpt_converted_externally_to_rlds'
75
+ # Load the dataset
76
+ dataset = tfds.builder_from_directory(
77
+ builder_dir=dataset_to_path(
78
+ DATASET_NAME, DATASET_DIR))
79
+ dataset = dataset.as_dataset(split='all')
80
+
81
+ # Inspect the dataset
82
+ for episode in dataset:
83
+ for step in episode['steps']:
84
+ print(step)
data/preprocess_scripts/calvin.py ADDED
@@ -0,0 +1,176 @@
1
+ import tensorflow as tf
2
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
3
+
4
+ import os
5
+ import fnmatch
6
+ import random
7
+
8
+
9
+ def _parse_function(proto):
10
+ keys_to_features = {
11
+ 'action': tf.io.FixedLenFeature([], tf.string),
12
+ 'robot_obs': tf.io.FixedLenFeature([], tf.string),
13
+ 'rgb_static': tf.io.FixedLenFeature([], tf.string),
14
+ 'rgb_gripper': tf.io.FixedLenFeature([], tf.string),
15
+ 'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
16
+ 'instruction': tf.io.FixedLenFeature([], tf.string),
17
+ }
18
+
19
+ parsed_features = tf.io.parse_single_example(proto, keys_to_features)
20
+
21
+ action = tf.io.parse_tensor(parsed_features['action'], out_type=tf.float64)
22
+ robot_obs = tf.io.parse_tensor(parsed_features['robot_obs'], out_type=tf.float64)
23
+ rgb_static = tf.io.parse_tensor(parsed_features['rgb_static'], out_type=tf.uint8)
24
+ rgb_gripper = tf.io.parse_tensor(parsed_features['rgb_gripper'], out_type=tf.uint8)
25
+ instruction = parsed_features['instruction']
26
+ terminate_episode = tf.cast(parsed_features['terminate_episode'], tf.int64)
27
+
28
+ action = tf.reshape(action, [7])
29
+ action = tf.cast(action, tf.float32)
30
+ robot_obs = tf.reshape(robot_obs, [15])
31
+ robot_obs = tf.cast(robot_obs, tf.float32)
32
+ rgb_static = tf.reshape(rgb_static, [200, 200, 3])
33
+ rgb_gripper = tf.reshape(rgb_gripper, [84, 84, 3])
34
+ # RGB to BGR
35
+ # rgb_static = rgb_static[:, :, ::-1]
36
+ # rgb_gripper = rgb_gripper[:, :, ::-1]
37
+
38
+ return {
39
+ 'action': action,
40
+ 'observation':{
41
+ 'robot_obs': robot_obs,
42
+ 'rgb_static': rgb_static,
43
+ 'rgb_gripper': rgb_gripper,
44
+ },
45
+ 'instruction': instruction,
46
+ 'terminate_episode': terminate_episode
47
+ }
48
+
49
+
50
+ def dataset_generator_from_tfrecords(seed):
51
+ tfrecord_path = './data/datasets/calvin/tfrecords/'
52
+ filepaths = []
53
+ for root, dirs, files in os.walk(tfrecord_path):
54
+ for filename in fnmatch.filter(files, '*.tfrecord'):
55
+ filepath = os.path.join(root, filename)
56
+ filepaths.append(filepath)
57
+
58
+ random.seed(seed)
59
+ random.shuffle(filepaths)
60
+ for filepath in filepaths:
61
+ raw_dataset = tf.data.TFRecordDataset(filepath)
62
+ dataset = raw_dataset.map(_parse_function)
63
+ yield {
64
+ 'steps': dataset
65
+ }
66
+
67
+
68
+ def load_dataset(seed):
69
+ dataset = tf.data.Dataset.from_generator(
70
+ lambda: dataset_generator_from_tfrecords(seed),
71
+ output_signature={
72
+ 'steps': tf.data.DatasetSpec(
73
+ element_spec={
74
+ 'action': tf.TensorSpec(shape=(7,), dtype=tf.float32),
75
+ 'observation':{
76
+ 'robot_obs': tf.TensorSpec(shape=(15,), dtype=tf.float32),
77
+ 'rgb_static': tf.TensorSpec(shape=(200,200,3), dtype=tf.uint8),
78
+ 'rgb_gripper': tf.TensorSpec(shape=(84,84,3), dtype=tf.uint8),
79
+ },
80
+ 'instruction': tf.TensorSpec(shape=(), dtype=tf.string),
81
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.int64),
82
+ }
83
+ )
84
+ }
85
+ )
86
+
87
+ return dataset
88
+
89
+
90
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
91
+ """
92
+ Convert terminate action to a boolean, where True means terminate.
93
+ """
94
+ return tf.where(
95
+ tf.equal(terminate_act, tf.constant(0, dtype=tf.int64)),
96
+ tf.constant(False),tf.constant(True))
97
+
98
+
99
+ def process_step(step: dict) -> dict:
100
+ """
101
+ Unify the action format and clean the task instruction.
102
+
103
+ DO NOT use python list, use tf.TensorArray instead.
104
+ """
105
+ # Convert raw action to our action
106
+ old_action = step['action']
107
+ step['action'] = {}
108
+ action = step['action']
109
+ step['action']['terminate'] = terminate_act_to_bool(step['terminate_episode'])
110
+ # ['actions']
111
+ # (dtype=np.float32, shape=(7,))
112
+ # tcp position (3): x,y,z in absolute world coordinates
113
+ # tcp orientation (3): euler angles x,y,z in absolute world coordinates
114
+ # gripper_action (1): binary (close = -1, open = 1)
115
+ eef_pos = old_action[:3]
116
+ eef_ang = euler_to_rotation_matrix(old_action[3:6])
117
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
118
+ gripper_open = (old_action[6] + 1) / 2
119
+ gripper_open = tf.expand_dims(gripper_open, axis=0)
120
+
121
+ # # No base found
122
+ arm_action = tf.concat([eef_pos, eef_ang, gripper_open], axis=0)
123
+ action['arm_concat'] = arm_action
124
+ # # Write the action format
125
+ action['format'] = tf.constant(
126
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
127
+
128
+ state = step['observation']
129
+ # ['robot_obs']
130
+ # (dtype=np.float32, shape=(15,))
131
+ # tcp position (3): x,y,z in world coordinates
132
+ # tcp orientation (3): euler angles x,y,z in world coordinates
133
+ # gripper opening width (1): in meter
134
+ # arm_joint_states (7): in rad
135
+ # gripper_action (1): binary (close = -1, open = 1)
136
+ eef_pos = state['robot_obs'][:3]
137
+ eef_ang = euler_to_rotation_matrix(state['robot_obs'][3:6])
138
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
139
+ gripper_open = (state['robot_obs'][14] + 1) / 2
140
+ gripper_open = tf.expand_dims(gripper_open, axis=0)
141
+ qpos = state['robot_obs'][7:14]
142
+
143
+ state['arm_concat'] = tf.concat([qpos,gripper_open,eef_pos,eef_ang], axis=0)
144
+ # # Write the state format
145
+ state['format'] = tf.constant(
146
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
147
+
148
+ # Clean the task instruction
149
+ # Define the replacements (old, new) as a dictionary
150
+ replacements = {
151
+ '_': ' ',
152
+ '1f': ' ',
153
+ '4f': ' ',
154
+ '-': ' ',
155
+ '50': ' ',
156
+ '55': ' ',
157
+ '56': ' ',
158
+
159
+ }
160
+ instr = step['instruction']
161
+ instr= clean_task_instruction(instr, replacements)
162
+ step['observation']['natural_language_instruction'] = instr
163
+
164
+ return step
165
+
166
+
167
+ if __name__ == "__main__":
168
+ import tensorflow_datasets as tfds
169
+ from data.utils import dataset_to_path
170
+
171
+ # Load the dataset
172
+ dataset = load_dataset(1717055919)
173
+ for data in dataset.take(1):
174
+ for step in data['steps']:
175
+ step = process_step(step)
176
+ print(step['observation']['natural_language_instruction'])
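`_parse_function` expects every feature to be a tensor serialized with `tf.io.serialize_tensor`. The writer sketch below produces one synthetic step with matching keys, dtypes, and shapes; the output path, zero-filled tensors, and instruction text are made up for illustration and are not how the real tfrecords under data/datasets/calvin/tfrecords/ were generated.

import numpy as np
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

with tf.io.TFRecordWriter('example_episode.tfrecord') as writer:
    example = tf.train.Example(features=tf.train.Features(feature={
        'action': _bytes_feature(tf.io.serialize_tensor(
            tf.constant(np.zeros(7), dtype=tf.float64)).numpy()),
        'robot_obs': _bytes_feature(tf.io.serialize_tensor(
            tf.constant(np.zeros(15), dtype=tf.float64)).numpy()),
        'rgb_static': _bytes_feature(tf.io.serialize_tensor(
            tf.zeros([200, 200, 3], dtype=tf.uint8)).numpy()),
        'rgb_gripper': _bytes_feature(tf.io.serialize_tensor(
            tf.zeros([84, 84, 3], dtype=tf.uint8)).numpy()),
        'terminate_episode': _int64_feature(0),
        'instruction': _bytes_feature(b'push the sliding door to the left'),
    }))
    writer.write(example.SerializeToString())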
data/preprocess_scripts/cmu_franka_exploration_dataset_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,75 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_euler,euler_to_quaternion
4
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
5
+ """
6
+ Convert terminate action to a boolean, where True means terminate.
7
+ """
8
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
9
+
10
+ def process_step(step: dict) -> dict:
11
+ """
12
+ Unify the action format and clean the task instruction.
13
+
14
+ DO NOT use python list, use tf.TensorArray instead.
15
+ """
16
+ # Convert raw action to our action
17
+
18
+ origin_action = step['action']
19
+ step['action']={}
20
+ action=step['action']
21
+ action['terminate']=terminate_act_to_bool(origin_action[7])
22
+
23
+ # gripper_open: 1-open, 0-closed
24
+
25
+ eef_pos=origin_action[:3]
26
+ eef_ang=origin_action[3:6]
27
+ eef_ang = euler_to_quaternion(eef_ang)
28
+ grip_open=origin_action[6:7]
29
+ # No base found
30
+
31
+ # Concatenate the action
32
+ action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
33
+
34
+ # Write the action format
35
+ action['format'] = tf.constant(
36
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
37
+
38
+ # No state found
39
+
40
+ # Clean the task instruction
41
+ # Define the replacements (old, new) as a dictionary
42
+ replacements = {
43
+ '_': ' ',
44
+ '1f': ' ',
45
+ '4f': ' ',
46
+ '-': ' ',
47
+ '50': ' ',
48
+ '55': ' ',
49
+ '56': ' ',
50
+
51
+ }
52
+ instr = step['language_instruction']
53
+ instr = clean_task_instruction(instr, replacements)
54
+ step['observation']['natural_language_instruction'] = instr
55
+
56
+ return step
57
+
58
+
59
+ if __name__ == "__main__":
60
+ import tensorflow_datasets as tfds
61
+ from data.utils import dataset_to_path
62
+
63
+ DATASET_DIR = 'data/datasets/openx_embod'
64
+ DATASET_NAME = 'cmu_franka_exploration_dataset_converted_externally_to_rlds'
65
+ # Load the dataset
66
+ dataset = tfds.builder_from_directory(
67
+ builder_dir=dataset_to_path(
68
+ DATASET_NAME, DATASET_DIR))
69
+ dataset = dataset.as_dataset(split='all')
70
+
71
+ # Inspect the dataset
72
+ for episode in dataset:
73
+ for step in episode['steps']:
74
+ print(step['action'][6:7])
75
+
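Several of these scripts (this one, austin_buds, cmu_stretch, nyu_rot, robo_net) call `euler_to_quaternion` from `data/utils.py`, which is outside this part of the diff. The sketch below shows the standard roll-pitch-yaw to (x, y, z, w) conversion it presumably performs; the angle order and output convention are assumptions, not the repository code.

import numpy as np

def euler_to_quaternion(euler_xyz):
    roll, pitch, yaw = euler_xyz
    cr, sr = np.cos(roll / 2), np.sin(roll / 2)
    cp, sp = np.cos(pitch / 2), np.sin(pitch / 2)
    cy, sy = np.cos(yaw / 2), np.sin(yaw / 2)
    x = sr * cp * cy - cr * sp * sy
    y = cr * sp * cy + sr * cp * sy
    z = cr * cp * sy - sr * sp * cy
    w = cr * cp * cy + sr * sp * sy
    return np.array([x, y, z, w])

print(euler_to_quaternion([0.0, 0.0, 0.0]))  # -> [0, 0, 0, 1], the identity rotation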
data/preprocess_scripts/cmu_play_fusion.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
5
+ """
6
+ Convert terminate action to a boolean, where True means terminate.
7
+ """
8
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
9
+
10
+ def process_step(step: dict) -> dict:
11
+ """
12
+ Unify the action format and clean the task instruction.
13
+
14
+ DO NOT use python list, use tf.TensorArray instead.
15
+ """
16
+ # Convert raw action to our action
17
+
18
+ origin_action = step['action']
19
+ step['action']={}
20
+ action=step['action']
21
+ action['terminate']=terminate_act_to_bool(origin_action[8])
22
+
23
+
24
+ eef_pos=origin_action[:3]
25
+ # eef_ang=quaternion_to_euler(origin_action[3:7])
26
+ eef_ang = origin_action[3:7]
27
+ grip_open=origin_action[7:8]
28
+ # No base found
29
+
30
+ # Concatenate the action
31
+ action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
32
+
33
+ # Write the action format
34
+ action['format'] = tf.constant(
35
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
36
+
37
+ # Convert raw state to our state
38
+ state = step['observation']
39
+ # Concatenate the state
40
+ arm_joint_ang=state['state'][:7]
41
+ grip_open=state['state'][7:8] * 11.765 # rescale to [0, 1]
42
+ state['arm_concat'] = tf.concat([arm_joint_ang,grip_open],axis=0)
43
+ # Write the state format
44
+ state['format'] = tf.constant(
45
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos")
46
+
47
+ # Clean the task instruction
48
+ # Define the replacements (old, new) as a dictionary
49
+ replacements = {
50
+ '_': ' ',
51
+ '1f': ' ',
52
+ '4f': ' ',
53
+ '-': ' ',
54
+ '50': ' ',
55
+ '55': ' ',
56
+ '56': ' ',
57
+
58
+ }
59
+ instr = step['language_instruction']
60
+ instr = clean_task_instruction(instr, replacements)
61
+ step['observation']['natural_language_instruction'] = instr
62
+
63
+ return step
64
+
65
+
66
+ if __name__ == "__main__":
67
+ import tensorflow_datasets as tfds
68
+ from data.utils import dataset_to_path
69
+
70
+ DATASET_DIR = 'data/datasets/openx_embod'
71
+ DATASET_NAME = 'cmu_play_fusion'
72
+ # Load the dataset
73
+ dataset = tfds.builder_from_directory(
74
+ builder_dir=dataset_to_path(
75
+ DATASET_NAME, DATASET_DIR))
76
+ dataset = dataset.as_dataset(split='all')
77
+
78
+ # Inspect the dataset
79
+ for episode in dataset:
80
+ for step in episode['steps']:
81
+ print(step['action'][6:7])
82
+
data/preprocess_scripts/cmu_stretch.py ADDED
@@ -0,0 +1,84 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_euler,euler_to_quaternion
4
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
5
+ """
6
+ Convert terminate action to a boolean, where True means terminate.
7
+ """
8
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
9
+
10
+ def process_step(step: dict) -> dict:
11
+ """
12
+ Unify the action format and clean the task instruction.
13
+
14
+ DO NOT use python list, use tf.TensorArray instead.
15
+ """
16
+ # Convert raw action to our action
17
+
18
+ origin_action = step['action']
19
+ step['action']={}
20
+ action=step['action']
21
+ action['terminate']=terminate_act_to_bool(origin_action[7])
22
+
23
+
24
+ eef_pos=origin_action[:3]
25
+ eef_ang=origin_action[3:6]
26
+ eef_ang = euler_to_quaternion(eef_ang)
27
+ grip_open=origin_action[6:7]
28
+ # No base found
29
+
30
+ # Concatenate the action
31
+ action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
32
+
33
+ # Write the action format
34
+ action['format'] = tf.constant(
35
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
36
+
37
+ # Convert raw state to our state
38
+ state = step['observation']
39
+ # Concatenate the state
40
+ eef_pos_x = state['state'][0:1]
41
+ eef_pos_z = state['state'][2:3]
42
+ grip_open = state['state'][3:4]
43
+ state['arm_concat'] = tf.concat(
44
+ [eef_pos_x, eef_pos_z, grip_open], axis=0)
45
+ # Write the state format
46
+ state['format'] = tf.constant(
47
+ "eef_pos_x,eef_pos_z,gripper_open")
48
+
49
+ # Clean the task instruction
50
+ # Define the replacements (old, new) as a dictionary
51
+ replacements = {
52
+ '_': ' ',
53
+ '1f': ' ',
54
+ '4f': ' ',
55
+ '-': ' ',
56
+ '50': ' ',
57
+ '55': ' ',
58
+ '56': ' ',
59
+
60
+ }
61
+ instr = step['language_instruction']
62
+ instr = clean_task_instruction(instr, replacements)
63
+ step['observation']['natural_language_instruction'] = instr
64
+
65
+ return step
66
+
67
+
68
+ if __name__ == "__main__":
69
+ import tensorflow_datasets as tfds
70
+ from data.utils import dataset_to_path
71
+
72
+ DATASET_DIR = 'data/datasets/openx_embod'
73
+ DATASET_NAME = 'cmu_stretch'
74
+ # Load the dataset
75
+ dataset = tfds.builder_from_directory(
76
+ builder_dir=dataset_to_path(
77
+ DATASET_NAME, DATASET_DIR))
78
+ dataset = dataset.as_dataset(split='all')
79
+
80
+ # Inspect the dataset
81
+ for episode in dataset:
82
+ for step in episode['steps']:
83
+ print(step['action'][6:7])
84
+
data/preprocess_scripts/droid.py ADDED
@@ -0,0 +1,78 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+ action_dict = step['action_dict']
14
+
15
+ # Robot action
16
+ eef_pos = action_dict['cartesian_position'][:3]
17
+ eef_ang = action_dict['cartesian_position'][3:6]
18
+ eef_ang = euler_to_rotation_matrix(eef_ang)
19
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
20
+ eef_pos_vel = action_dict['cartesian_velocity'][:3]
21
+ eef_ang_vel = action_dict['cartesian_velocity'][3:6]
22
+ joint_pos = action_dict['joint_position']
23
+ joint_vel = action_dict['joint_velocity']
24
+ grip_pos = action_dict['gripper_position']
25
+ grip_vel = action_dict['gripper_velocity']
26
+
27
+ # Concatenate the action
28
+ step['action'] = {}
29
+ action = step['action']
30
+
31
+ arm_action = tf.concat([eef_pos, eef_ang, eef_pos_vel, eef_ang_vel, joint_pos, joint_vel, grip_pos, grip_vel], axis=0)
32
+ action['arm_concat'] = arm_action
33
+ action['terminate'] = step['is_terminal']
34
+
35
+ # Write the action format
36
+ action['format'] = tf.constant(
37
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_0_vel")
38
+
39
+ # Convert raw state to our state
40
+ # Robot state
41
+ state = step['observation']
42
+ eef_pos = state['cartesian_position'][:3]
43
+ eef_ang = state['cartesian_position'][3:6]
44
+ eef_ang = euler_to_rotation_matrix(eef_ang)
45
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
46
+ joint_pos = state['joint_position']
47
+ grip_pos = 1 - state['gripper_position']
48
+
49
+ # Concatenate the state
50
+ state['arm_concat'] = tf.concat([
51
+ joint_pos,grip_pos,eef_pos,eef_ang], axis=0)
52
+
53
+
54
+ # Write the state format
55
+ state['format'] = tf.constant(
56
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
57
+
58
+ # Clean the task instruction
59
+ # Define the replacements (old, new) as a dictionary
60
+ replacements = {
61
+ '_': ' ',
62
+ '1f': ' ',
63
+ '4f': ' ',
64
+ '-': ' ',
65
+ '50': ' ',
66
+ '55': ' ',
67
+ '56': ' ',
68
+
69
+ }
70
+ instr = step['language_instruction']
71
+ instr = clean_task_instruction(instr, replacements)
72
+ step['observation']['natural_language_instruction'] = instr
73
+
74
+ return step
75
+
76
+
77
+ if __name__ == "__main__":
78
+ pass
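The DROID script converts Euler angles with `euler_to_rotation_matrix` before the 6D encoding. Below is a rough sketch under the assumption of roll-pitch-yaw angles composed as R = Rz @ Ry @ Rx; the authoritative version lives in data/utils.py and may use a different convention.

import numpy as np

def euler_to_rotation_matrix(euler_xyz):
    roll, pitch, yaw = euler_xyz
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(roll), -np.sin(roll)],
                   [0, np.sin(roll),  np.cos(roll)]])
    Ry = np.array([[ np.cos(pitch), 0, np.sin(pitch)],
                   [0, 1, 0],
                   [-np.sin(pitch), 0, np.cos(pitch)]])
    Rz = np.array([[np.cos(yaw), -np.sin(yaw), 0],
                   [np.sin(yaw),  np.cos(yaw), 0],
                   [0, 0, 1]])
    return Rz @ Ry @ Rx

R = euler_to_rotation_matrix([0.0, 0.0, np.pi / 2])  # 90 degree yaw
print(np.round(R, 3))                                # x-axis maps to y-axis, as expected
# The 6D encoding then keeps the first two columns of R, as in the earlier sketch.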
data/preprocess_scripts/fractal20220817_data.py ADDED
@@ -0,0 +1,92 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, quaternion_to_rotation_matrix,\
4
+ rotation_matrix_to_ortho6d
5
+
6
+
7
+
8
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
9
+ """
10
+ Convert terminate action to a boolean, where True means terminate.
11
+ """
12
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
13
+
14
+
15
+ def process_step(step: dict) -> dict:
16
+ """
17
+ Unify the action format and clean the task instruction.
18
+
19
+ DO NOT use python list, use tf.TensorArray instead.
20
+ """
21
+ # Convert raw action to our action
22
+ action = step['action']
23
+ action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
24
+
25
+ eef_delta_pos = action['world_vector']
26
+ eef_ang = action['rotation_delta']
27
+ eef_ang = euler_to_quaternion(eef_ang)
28
+ grip_open = 1 - (action['gripper_closedness_action'] + 1) / 2
29
+ # Multiplied by 3 Hz to get units m/s and rad/s
30
+ base_delta_pos = action['base_displacement_vector'] * 3
31
+ base_delta_ang = action['base_displacement_vertical_rotation'] * 3
32
+
33
+ # Concatenate the action
34
+ arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
35
+ action['arm_concat'] = arm_action
36
+ base_action = tf.concat([base_delta_pos, base_delta_ang], axis=0)
37
+ action['base_concat'] = base_action
38
+
39
+ # Write the action format
40
+ action['format'] = tf.constant(
41
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open,base_vel_x,base_vel_y,base_angular_vel")
42
+
43
+ # Convert raw state to our state
44
+ state = step['observation']
45
+ eef_pos = state['base_pose_tool_reached'][:3]
46
+ # eef_ang = quaternion_to_euler(state['base_pose_tool_reached'][3:])
47
+ eef_ang = quaternion_to_rotation_matrix(state['base_pose_tool_reached'][3:])
48
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
49
+ grip_open = 1 - state['gripper_closed']
50
+
51
+ # Concatenate the state
52
+ state['arm_concat'] = tf.concat([eef_pos, eef_ang, grip_open], axis=0)
53
+
54
+ # Write the state format
55
+ state['format'] = tf.constant(
56
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
57
+
58
+ # Clean the task instruction
59
+ # Define the replacements (old, new) as a dictionary
60
+ replacements = {
61
+ '_': ' ',
62
+ '1f': ' ',
63
+ '4f': ' ',
64
+ '-': ' ',
65
+ '50': ' ',
66
+ '55': ' ',
67
+ '56': ' ',
68
+
69
+ }
70
+ instr = step['observation']['natural_language_instruction']
71
+ instr = clean_task_instruction(instr, replacements)
72
+ step['observation']['natural_language_instruction'] = instr
73
+
74
+ return step
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import tensorflow_datasets as tfds
79
+ from data.utils import dataset_to_path
80
+
81
+ DATASET_DIR = 'data/datasets/openx_embod'
82
+ DATASET_NAME = 'fractal20220817_data'
83
+ # Load the dataset
84
+ dataset = tfds.builder_from_directory(
85
+ builder_dir=dataset_to_path(
86
+ DATASET_NAME, DATASET_DIR))
87
+ dataset = dataset.as_dataset(split='all')
88
+
89
+ # Inspect the dataset
90
+ for episode in dataset:
91
+ for step in episode['steps']:
92
+ print(step)
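Two rescalings above are easy to sanity-check numerically: `gripper_closedness_action` in [-1, 1] is mapped to an open fraction in [0, 1], and per-step base displacements recorded at 3 Hz are turned into approximate velocities. A tiny worked example (illustration only, made-up numbers):

# gripper_closedness_action: -1 commands an open gripper, +1 a closed one.
for closedness in (-1.0, 0.0, 1.0):
    grip_open = 1 - (closedness + 1) / 2
    print(closedness, '->', grip_open)   # -1 -> 1.0 (open), 0 -> 0.5, 1 -> 0.0 (closed)

# Base displacements are per-step values recorded at 3 Hz, so multiplying by 3
# converts them to approximate velocities in m/s and rad/s.
step_displacement_m = 0.05               # hypothetical 5 cm forward motion in one step
print(step_displacement_m * 3)           # 0.15 m/s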
data/preprocess_scripts/iamlab_cmu_pickup_insert_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,80 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction,quaternion_to_euler
4
+
5
+ def process_step(step: dict) -> dict:
6
+ """
7
+ Unify the action format and clean the task instruction.
8
+
9
+ DO NOT use python list, use tf.TensorArray instead.
10
+ """
11
+ # Convert raw action to our action
12
+
13
+ origin_action = step['action']
14
+ step['action']={}
15
+ action=step['action']
16
+
17
+ eef_delta_pos = origin_action[:3]
18
+ # delta ZYX euler angles
19
+ # eef_ang=quaternion_to_euler(origin_action[3:7])
20
+ eef_ang = origin_action[3:7]
21
+ grip_open=origin_action[7:8]
22
+
23
+ # No base found
24
+
25
+ # Concatenate the action
26
+ action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
27
+
28
+ # Write the action format
29
+ action['format'] = tf.constant(
30
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
31
+
32
+ # Convert raw state to our state
33
+ state = step['observation']
34
+ # Concatenate the state
35
+ # 7x robot joint angles, 1x gripper status, 6x joint torques, 6x end-effector force
36
+ arm_joint_ang=state['state'][:7]
37
+
38
+ grip_open=state['state'][7:8]
39
+
40
+ state['arm_concat'] = tf.concat([arm_joint_ang,grip_open],axis=0)
41
+
42
+ # Write the state format
43
+ state['format'] = tf.constant(
44
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open")
45
+
46
+ # Clean the task instruction
47
+ # Define the replacements (old, new) as a dictionary
48
+ replacements = {
49
+ '_': ' ',
50
+ '1f': ' ',
51
+ '4f': ' ',
52
+ '-': ' ',
53
+ '50': ' ',
54
+ '55': ' ',
55
+ '56': ' ',
56
+
57
+ }
58
+ instr = step['language_instruction']
59
+ instr = clean_task_instruction(instr, replacements)
60
+ step['observation']['natural_language_instruction'] = instr
61
+
62
+ return step
63
+
64
+
65
+ if __name__ == "__main__":
66
+ import tensorflow_datasets as tfds
67
+ from data.utils import dataset_to_path
68
+
69
+ DATASET_DIR = 'data/datasets/openx_embod'
70
+ DATASET_NAME = 'iamlab_cmu_pickup_insert_converted_externally_to_rlds'
71
+ # Load the dataset
72
+ dataset = tfds.builder_from_directory(
73
+ builder_dir=dataset_to_path(
74
+ DATASET_NAME, DATASET_DIR))
75
+ dataset = dataset.as_dataset(split='all')
76
+
77
+ # Inspect the dataset
78
+ for episode in dataset:
79
+ for step in episode['steps']:
80
+ print(step)
data/preprocess_scripts/libero_goal_no_noops.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+ action_dict = step['action']
14
+
15
+ # Robot action
16
+ # eef_pos = action_dict['ee_pos'][:3]
17
+ # eef_ang = action_dict['ee_pos'][3:6]
18
+ # eef_ang = euler_to_rotation_matrix(eef_ang)
19
+ # eef_ang = rotation_matrix_to_ortho6d(eef_ang)
20
+ eef_pos_vel = action_dict[:3]
21
+ eef_ang_vel = action_dict[3:6]
22
+ # joint_pos = action_dict['joint_pos'][:-1]
23
+ # joint_vel = action_dict['delta_joint'][:-1]
24
+ grip_pos = 1 - tf.clip_by_value(action_dict[-1:], 0, 1)
25
+
26
+ # grip_vel = action_dict['gripper_velocity']
27
+
28
+ # Concatenate the action
29
+ step['action'] = {}
30
+ action = step['action']
31
+
32
+ arm_action = tf.concat([eef_pos_vel, eef_ang_vel, grip_pos], axis=0)
33
+ action['arm_concat'] = arm_action
34
+ # action['terminate'] = step['is_terminal']
35
+
36
+ # Write the action format
37
+ action['format'] = tf.constant(
38
+ "eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,gripper_joint_0_pos")
39
+
40
+ # Convert raw state to our state
41
+ # Robot state
42
+ state = step['observation']
43
+ # print(state.keys())
44
+ # image = step['observation']['image']
45
+ eef_pos = state['state'][:3]
46
+ eef_ang = state['state'][3:6]
47
+ eef_ang = euler_to_rotation_matrix(eef_ang)
48
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
49
+ # joint_pos = state['joint_pos'][:-1]
50
+ grip_pos = state['state'][-2:]
51
+
52
+ # Concatenate the state
53
+ state['arm_concat'] = tf.concat([
54
+ grip_pos,eef_pos,eef_ang], axis=0)
55
+
56
+
57
+ # Write the state format
58
+ state['format'] = tf.constant(
59
+ "gripper_joint_0_pos,gripper_joint_1_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
60
+
61
+ # Clean the task instruction
62
+ # Define the replacements (old, new) as a dictionary
63
+ replacements = {
64
+ '_': ' ',
65
+ '1f': ' ',
66
+ '4f': ' ',
67
+ '-': ' ',
68
+ '50': ' ',
69
+ '55': ' ',
70
+ '56': ' ',
71
+
72
+ }
73
+ instr = step['language_instruction']
74
+ # instr = clean_task_instruction(instr, replacements)
75
+ step['observation'] = state
76
+ step['observation']['natural_language_instruction'] = instr
77
+
78
+ return step
79
+
80
+
81
+ if __name__ == "__main__":
82
+ pass
data/preprocess_scripts/libero_spatial_no_noops.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+ action_dict = step['action']
14
+
15
+ # Robot action
16
+ # eef_pos = action_dict['ee_pos'][:3]
17
+ # eef_ang = action_dict['ee_pos'][3:6]
18
+ # eef_ang = euler_to_rotation_matrix(eef_ang)
19
+ # eef_ang = rotation_matrix_to_ortho6d(eef_ang)
20
+ eef_pos_vel = action_dict[:3]
21
+ eef_ang_vel = action_dict[3:6]
22
+ # joint_pos = action_dict['joint_pos'][:-1]
23
+ # joint_vel = action_dict['delta_joint'][:-1]
24
+ grip_pos = 1 - tf.clip_by_value(action_dict[-1:], 0, 1)
25
+
26
+ # grip_vel = action_dict['gripper_velocity']
27
+
28
+ # Concatenate the action
29
+ step['action'] = {}
30
+ action = step['action']
31
+
32
+ arm_action = tf.concat([eef_pos_vel, eef_ang_vel, grip_pos], axis=0)
33
+ action['arm_concat'] = arm_action
34
+ # action['terminate'] = step['is_terminal']
35
+
36
+ # Write the action format
37
+ action['format'] = tf.constant(
38
+ "eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,gripper_joint_0_pos")
39
+
40
+ # Convert raw state to our state
41
+ # Robot state
42
+ state = step['observation']
43
+ # print(state.keys())
44
+ # image = step['observation']['image']
45
+ eef_pos = state['state'][:3]
46
+ eef_ang = state['state'][3:6]
47
+ eef_ang = euler_to_rotation_matrix(eef_ang)
48
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
49
+ # joint_pos = state['joint_pos'][:-1]
50
+ grip_pos = state['state'][-2:]
51
+
52
+ # Concatenate the state
53
+ state['arm_concat'] = tf.concat([
54
+ grip_pos,eef_pos,eef_ang], axis=0)
55
+
56
+
57
+ # Write the state format
58
+ state['format'] = tf.constant(
59
+ "gripper_joint_0_pos,gripper_joint_1_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
60
+
61
+ # Clean the task instruction
62
+ # Define the replacements (old, new) as a dictionary
63
+ replacements = {
64
+ '_': ' ',
65
+ '1f': ' ',
66
+ '4f': ' ',
67
+ '-': ' ',
68
+ '50': ' ',
69
+ '55': ' ',
70
+ '56': ' ',
71
+
72
+ }
73
+ instr = step['language_instruction']
74
+ # instr = clean_task_instruction(instr, replacements)
75
+ step['observation'] = state
76
+ step['observation']['natural_language_instruction'] = instr
77
+
78
+ return step
79
+
80
+
81
+ if __name__ == "__main__":
82
+ pass
data/preprocess_scripts/nyu_rot_dataset_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,82 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
4
+ rotation_matrix_to_ortho6d
5
+
6
+ def process_step(step: dict) -> dict:
7
+ """
8
+ Unify the action format and clean the task instruction.
9
+
10
+ DO NOT use python list, use tf.TensorArray instead.
11
+ """
12
+ # Convert raw action to our action
13
+
14
+ origin_action = step['action']
15
+ step['action']={}
16
+ action=step['action']
17
+ action['terminate'] = step['is_terminal']
18
+
19
+ eef_delta_pos = origin_action[:3]
20
+ eef_ang=origin_action[3:6]
21
+ eef_ang = euler_to_quaternion(eef_ang)
22
+ # gripper_open: 0-open, 1-closed
23
+ grip_open=tf.where(tf.equal(origin_action[6:],tf.constant(0.0)),tf.constant(1.0),tf.constant(0.0))
24
+
25
+ # No base found
26
+
27
+ # Concatenate the action
28
+ action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
29
+
30
+ # Write the action format
31
+ action['format'] = tf.constant(
32
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
33
+
34
+ # Convert raw state to our state
35
+ state = step['observation']
36
+ eef_pos=state['state'][:3]
37
+ eef_ang=state['state'][3:6]
38
+ eef_ang = euler_to_rotation_matrix(eef_ang)
39
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
40
+ grip_open=1-state['state'][6:7]
41
+ # Concatenate the state
42
+ state['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
43
+
44
+ # Write the state format
45
+ state['format'] = tf.constant(
46
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
47
+
48
+ # Clean the task instruction
49
+ # Define the replacements (old, new) as a dictionary
50
+ replacements = {
51
+ '_': ' ',
52
+ '1f': ' ',
53
+ '4f': ' ',
54
+ '-': ' ',
55
+ '50': ' ',
56
+ '55': ' ',
57
+ '56': ' ',
58
+
59
+ }
60
+ instr = step['language_instruction']
61
+ instr = clean_task_instruction(instr, replacements)
62
+ step['observation']['natural_language_instruction'] = instr
63
+
64
+ return step
65
+
66
+
67
+ if __name__ == "__main__":
68
+ import tensorflow_datasets as tfds
69
+ from data.utils import dataset_to_path
70
+
71
+ DATASET_DIR = 'data/datasets/openx_embod'
72
+ DATASET_NAME = 'nyu_rot_dataset_converted_externally_to_rlds'
73
+ # Load the dataset
74
+ dataset = tfds.builder_from_directory(
75
+ builder_dir=dataset_to_path(
76
+ DATASET_NAME, DATASET_DIR))
77
+ dataset = dataset.as_dataset(split='all')
78
+
79
+ # Inspect the dataset
80
+ for episode in dataset:
81
+ for step in episode['steps']:
82
+ print(step)
data/preprocess_scripts/robo_net.py ADDED
@@ -0,0 +1,71 @@
1
+ import tensorflow as tf
2
+ import numpy as np
3
+
4
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
5
+ rotation_matrix_to_ortho6d
6
+
7
+
8
+ def process_step(step: dict) -> dict:
9
+ """
10
+ Unify the action format and clean the task instruction.
11
+
12
+ DO NOT use python list, use tf.TensorArray instead.
13
+ """
14
+ # Convert raw action to our action
15
+ action = step['action']
16
+ eef_delta_pos = action[:3]
17
+ eef_delta_angle_yaw = action[3:4]
18
+ eef_ang = tf.stack([0.0, 0.0, eef_delta_angle_yaw[0]], axis=0)
19
+ eef_ang = euler_to_quaternion(eef_ang)
20
+ eef_gripper_open = (1 - action[4:5]) / 2
21
+
22
+ step['action'] = {}
23
+ action = step['action']
24
+ action['terminate'] = step['is_terminal']
25
+
26
+ # No base found
27
+
28
+ # Concatenate the action
29
+ arm_action = tf.concat([eef_delta_pos, eef_ang, eef_gripper_open], axis=0)
30
+ action['arm_concat'] = arm_action
31
+
32
+ # Write the action format
33
+ action['format'] = tf.constant(
34
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
35
+
36
+ # Convert raw state to our state
37
+ state = step['observation']
38
+ eef_pos = state['state'][:3]
39
+ eef_ang_yaw = state['state'][3:4]
40
+ eef_ang = tf.stack([0.0, 0.0, eef_ang_yaw[0]], axis=0)
41
+ eef_ang = euler_to_rotation_matrix(eef_ang)
42
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
43
+ grip_joint_pos = state['state'][4:5]
44
+ # If abs(grip_joint_pos) > 3.15, then convert it to the radian
45
+ grip_joint_pos = tf.cond(tf.greater(tf.abs(grip_joint_pos), 3.15),
46
+ lambda: grip_joint_pos / 180 * np.pi,
47
+ lambda: grip_joint_pos)
48
+ # Concatenate the state
49
+ state['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_joint_pos],axis=0)
50
+
51
+ # Write the state format
52
+ state['format'] = tf.constant(
53
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
54
+
55
+ # Clean the task instruction
56
+ # Define the replacements (old, new) as a dictionary
57
+ replacements = {
58
+ '_': ' ',
59
+ '1f': ' ',
60
+ '4f': ' ',
61
+ '-': ' ',
62
+ '50': ' ',
63
+ '55': ' ',
64
+ '56': ' ',
65
+
66
+ }
67
+ instr = step['language_instruction']
68
+ instr = clean_task_instruction(instr, replacements)
69
+ step['observation']['natural_language_instruction'] = instr
70
+
71
+ return step
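The `tf.cond` guard above treats any gripper joint value with magnitude beyond roughly pi (3.15) as degrees and converts it to radians, otherwise it keeps the value as-is. A tiny plain-Python illustration of the same heuristic (not repository code):

import numpy as np

def normalize_grip_joint(value):
    # Values larger than ~pi cannot be radians for a gripper joint, so assume degrees.
    return value / 180 * np.pi if abs(value) > 3.15 else value

print(normalize_grip_joint(45.0))   # 45 degrees -> ~0.785 rad
print(normalize_grip_joint(0.8))    # already radians, unchanged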
data/preprocess_scripts/robomimic_lift_ph.py ADDED
@@ -0,0 +1,97 @@
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+
5
+
6
+ def load_dataset():
7
+ builder = tfds.builder('robomimic_ph/lift_ph_image')
8
+ builder.download_and_prepare()
9
+ ds = builder.as_dataset(split='train', shuffle_files=True)
10
+ return ds
11
+
12
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
13
+ """
14
+ Convert terminate action to a boolean, where True means terminate.
15
+ """
16
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
17
+
18
+ def process_step(step: dict) -> dict:
19
+ """
20
+ Unify the action format and clean the task instruction.
21
+
22
+ DO NOT use python list, use tf.TensorArray instead.
23
+ """
24
+ # format refers to https://www.tensorflow.org/datasets/catalog/robomimic_mg
25
+ # Convert raw action to our action
26
+ eef = step['action']
27
+ step['action'] = {}
28
+ action = step['action']
29
+ action['terminate'] = step['is_terminal']
30
+
31
+ eef_delta_pos = eef[:3]
32
+ eef_ang = quaternion_to_euler(eef[3:])
33
+
34
+ # No base found
35
+
36
+ # Concatenate the action
37
+ arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
38
+ action['arm_concat'] = arm_action
39
+
40
+ # Write the action format
41
+ action['format'] = tf.constant(
42
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
43
+
44
+ # Convert raw state to our state
45
+ state = step['observation']
46
+ arm_joint_pos = state['robot0_joint_pos']
47
+ arm_joint_vel = state['robot0_joint_vel']
48
+ gripper_pos = state['robot0_gripper_qpos']
49
+ gripper_vel = state['robot0_gripper_qvel']
50
+ eef_pos = state['robot0_eef_pos']
51
+ eef_ang = quaternion_to_euler(state['robot0_eef_quat'])
52
+
53
+ state['arm_concat'] = tf.concat([arm_joint_pos, arm_joint_vel, gripper_pos,gripper_vel,eef_pos,eef_ang], axis=0)
54
+ # convert to tf32
55
+ state['arm_concat'] = tf.cast(state['arm_concat'], tf.float32)
56
+ # Write the state format
57
+ state['format'] = tf.constant(
58
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_0_vel,gripper_joint_1_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_roll,eef_angle_pitch,eef_angle_yaw")
59
+
60
+ # Clean the task instruction
61
+ # Define the replacements (old, new) as a dictionary
62
+ replacements = {
63
+ '_': ' ',
64
+ '1f': ' ',
65
+ '4f': ' ',
66
+ '-': ' ',
67
+ '50': ' ',
68
+ '55': ' ',
69
+ '56': ' ',
70
+
71
+ }
72
+ # manual added by lbg
73
+ instr = "lift the object on the table"
74
+ instr = clean_task_instruction(instr, replacements)
75
+ step['observation']['natural_language_instruction'] = instr
76
+
77
+ return step
78
+
79
+
80
+ if __name__ == "__main__":
81
+ import tensorflow_datasets as tfds
82
+ from data.utils import dataset_to_path
83
+
84
+ DATASET_DIR = 'data/datasets/openx_embod'
85
+ DATASET_NAME = 'roboturk'
86
+ # Load the dataset
87
+ dataset = tfds.builder_from_directory(
88
+ builder_dir=dataset_to_path(
89
+ DATASET_NAME, DATASET_DIR))
90
+ dataset = dataset.as_dataset(split='all').take(1)
91
+
92
+ # Inspect the dataset
93
+ ze=tf.constant(0.0)
94
+ for episode in dataset:
95
+ for step in episode['steps']:
96
+ print(step)
97
+ break
data/preprocess_scripts/robomimic_square_ph.py ADDED
@@ -0,0 +1,97 @@
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+
5
+
6
+ def load_dataset():
7
+ builder = tfds.builder('robomimic_ph/square_ph_image')
8
+ builder.download_and_prepare()
9
+ ds = builder.as_dataset(split='train', shuffle_files=True)
10
+ return ds
11
+
12
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
13
+ """
14
+ Convert terminate action to a boolean, where True means terminate.
15
+ """
16
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
17
+
18
+ def process_step(step: dict) -> dict:
19
+ """
20
+ Unify the action format and clean the task instruction.
21
+
22
+ DO NOT use python list, use tf.TensorArray instead.
23
+ """
24
+ # format refers to https://www.tensorflow.org/datasets/catalog/robomimic_mg
25
+ # Convert raw action to our action
26
+ eef = step['action']
27
+ step['action'] = {}
28
+ action = step['action']
29
+ action['terminate'] = step['is_terminal']
30
+
31
+ eef_delta_pos = eef[:3]
32
+ eef_ang = quaternion_to_euler(eef[3:])
33
+
34
+ # No base found
35
+
36
+ # Concatenate the action
37
+ arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
38
+ action['arm_concat'] = arm_action
39
+
40
+ # Write the action format
41
+ action['format'] = tf.constant(
42
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
43
+
44
+ # Convert raw state to our state
45
+ state = step['observation']
46
+ arm_joint_pos = state['robot0_joint_pos']
47
+ arm_joint_vel = state['robot0_joint_vel']
48
+ gripper_pos = state['robot0_gripper_qpos']
49
+ gripper_vel = state['robot0_gripper_qvel']
50
+ eef_pos = state['robot0_eef_pos']
51
+ eef_ang = quaternion_to_euler(state['robot0_eef_quat'])
52
+
53
+ state['arm_concat'] = tf.concat([arm_joint_pos, arm_joint_vel, gripper_pos,gripper_vel,eef_pos,eef_ang], axis=0)
54
+ # convert to tf32
55
+ state['arm_concat'] = tf.cast(state['arm_concat'], tf.float32)
56
+ # Write the state format
57
+ state['format'] = tf.constant(
58
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_0_vel,gripper_joint_1_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_roll,eef_angle_pitch,eef_angle_yaw")
59
+
60
+ # Clean the task instruction
61
+ # Define the replacements (old, new) as a dictionary
62
+ replacements = {
63
+ '_': ' ',
64
+ '1f': ' ',
65
+ '4f': ' ',
66
+ '-': ' ',
67
+ '50': ' ',
68
+ '55': ' ',
69
+ '56': ' ',
70
+
71
+ }
72
+ # manual added by lbg
73
+ instr = "move the square across the cube"
74
+ instr = clean_task_instruction(instr, replacements)
75
+ step['observation']['natural_language_instruction'] = instr
76
+
77
+ return step
78
+
79
+
80
+ if __name__ == "__main__":
81
+ import tensorflow_datasets as tfds
82
+ from data.utils import dataset_to_path
83
+
84
+ DATASET_DIR = 'data/datasets/openx_embod'
85
+ DATASET_NAME = 'roboturk'
86
+ # Load the dataset
87
+ dataset = tfds.builder_from_directory(
88
+ builder_dir=dataset_to_path(
89
+ DATASET_NAME, DATASET_DIR))
90
+ dataset = dataset.as_dataset(split='all').take(1)
91
+
92
+ # Inspect the dataset
93
+ ze=tf.constant(0.0)
94
+ for episode in dataset:
95
+ for step in episode['steps']:
96
+ print(step)
97
+ break
data/preprocess_scripts/roboset.py ADDED
@@ -0,0 +1,367 @@
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+ import tensorflow as tf
5
+ import h5py
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import os
9
+ import imageio
10
+ import concurrent.futures
11
+ import fnmatch
12
+ import cv2
13
+ import random
14
+ path2json = {
15
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaPlanarPushReal_v2d/set_17_plannar_push_eval": [
16
+ "Push the green object to the red line."
17
+ ],
18
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientRealRP03_v2d_set16/set_16_bin_reorient_3": [
19
+ "Pick up the bottle and and stand it upright."
20
+ ],
21
+ "mnt/raid5/data/jaydv/robohive_base/demonstrations/orange_block/set_7_orange_block": [
22
+ "Pick up the orange block."
23
+ ],
24
+ "mnt/raid5/data/jaydv/robohive_base/demonstrations/wooden_block": [
25
+ "Pick up the wooden block."
26
+ ],
27
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientReal_v2d-orig/set_15_bin_reorient_eval_2": [
28
+ "Pick up the bottle and and stand it upright."
29
+ ],
30
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPushReal_v2d_set14/set_14_bin_push": [
31
+ "Push the object to the red line."
32
+ ],
33
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickRealRP05_v2d/set_11_bottle_pick": [
34
+ "Pick up the bottle."
35
+ ],
36
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickRealRP03_v2d/set_10_bin_pick_2004": [
37
+ "Pick up the wooden block."
38
+ ],
39
+ "mnt/raid5/data/jaydv/bin_pick_data_30029/set_2_softtoys": [
40
+ "Pick up one toy in the basket."
41
+ ],
42
+ "home/jaydv/Documents/RoboSet/pick_banana_from_toaster_place_on_table_data": [
43
+ "Pick banana from toaster and place on table."
44
+ ],
45
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickReal_v2d/set_12_bin_pick_eval": [
46
+ "Pick up the block."
47
+ ],
48
+ "mnt/raid5/data/roboset/v0.3/flap_open_toaster_oven_data": [
49
+ "Flap open toaster."
50
+ ],
51
+ "mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientReal_v2d_set13/set_13_bin_reorient": [
52
+ "Pick up the bottle and stand it upright."
53
+ ],
54
+ "home/jaydv/Documents/RoboSet/drag_mug_from_right_to_left_data": [
55
+ "Drag mug right to left."
56
+ ],
57
+ "home/jaydv/Documents/RoboSet/drag_strainer_backward_data": [
58
+ "Drag strainer backwards."
59
+ ],
60
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_slide_close_drawer_scene_4": [
61
+ "Slide and close the drawer."
62
+ ],
63
+ "mnt/raid5/data/roboset/v0.4/heat_soup/scene_4/baking_slide_in_bowl_scene_4": [
64
+ "Place the bowl into the container."
65
+ ],
66
+ "set_5_bottle_cube_14": [
67
+ "Pick up the bottle."
68
+ ],
69
+ "home/jaydv/Documents/RoboSet/drag_mug_forward_data": [
70
+ "Drag mug forwards."
71
+ ],
72
+ "mnt/raid5/data/roboset/v0.4/clean_kitchen/scene_3/clean_kitchen_pick_towel_scene_3": [
73
+ "Pick up the towel from the oven."
74
+ ],
75
+ "mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_pick_bowl_scene_2": [
76
+ "Pick up the bowl."
77
+ ],
78
+ "mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_slide_in_bowl_scene_2": [
79
+ "Place the bowl into the oven."
80
+ ],
81
+ "mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_close_oven_scene_2": [
82
+ "Flap and close the oven."
83
+ ],
84
+ "home/jaydv/Documents/RoboSet/pick_banana_place_in_strainer_data": [
85
+ "Pick banana and place it in strainer."
86
+ ],
87
+ "mnt/raid5/data/roboset/v0.3/pick_banana_place_in_mug_data": [
88
+ "Pick banana and place it in mug."
89
+ ],
90
+ "mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_pick_tea_scene_2": [
91
+ "Pick up the tea from the container."
92
+ ],
93
+ "home/jaydv/Documents/RoboSet/drag_strainer_forward_data": [
94
+ "Drag strainer forwards."
95
+ ],
96
+ "mnt/raid5/data/roboset/v0.3/pick_ketchup_from_strainer_place_on_table_data": [
97
+ "Pick ketchup from strainer and place it on the table."
98
+ ],
99
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_place_butter_scene_4": [
100
+ "Place the butter on the cutting board."
101
+ ],
102
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_slide_open_drawer_scene_1": [
103
+ "Slide and open the drawer."
104
+ ],
105
+ "mnt/raid5/data/roboset/v0.3/drag_strainer_right_to_left_data": [
106
+ "Drag strainer right to left."
107
+ ],
108
+ "home/jaydv/Documents/RoboSet/pick_banana_from_plate_place_on_table_data": [
109
+ "Pick banana from plate and place on table."
110
+ ],
111
+ "mnt/raid5/data/roboset/v0.3/pick_ketchup_from_plate_place_on_table_data": [
112
+ "Pick ketchup from plate and place it on table."
113
+ ],
114
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_slide_close_drawer_scene_1": [
115
+ "Slide and close the drawer."
116
+ ],
117
+ "mnt/raid5/data/roboset/v0.3/drag_strainer_left_to_right_data": [
118
+ "Drag strainer left to right."
119
+ ],
120
+ "pick_ketchup_place_on_toaster_data": [
121
+ "Pick ketchup from table and place on toaster."
122
+ ],
123
+ "mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_place_tea_scene_2": [
124
+ "Place the tea into the cup."
125
+ ],
126
+ "mnt/raid5/data/roboset/v0.3/pick_ketchup_place_in_strainer_data": [
127
+ "Pick ketchup from the table and place it in strainer."
128
+ ],
129
+ "home/jaydv/Documents/RoboSet/pick_ketchup_place_on_plate_data": [
130
+ "Pick ketchup from table and place on plate."
131
+ ],
132
+ "home/jaydv/Documents/RoboSet/drag_mug_backward_data": [
133
+ "Drag mug backwards."
134
+ ],
135
+ "set_1_blocks_897": [
136
+ "Pick up one block in the basket."
137
+ ],
138
+ "mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_place_lid_scene_2": [
139
+ "Place lid on the cutting board."
140
+ ],
141
+ "mnt/raid5/data/roboset/v0.4/heat_soup/scene_4/baking_pick_bowl_scene_4": [
142
+ "Pick up the bowl."
143
+ ],
144
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_pick_butter_scene_4": [
145
+ "Pick up the butter from the drawer."
146
+ ],
147
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_pick_butter_scene_1": [
148
+ "Pick up the butter from the drawer."
149
+ ],
150
+ "home/jaydv/Documents/RoboSet/flap_close_toaster_oven_data": [
151
+ "Flap close toaster."
152
+ ],
153
+ "home/jaydv/Documents/RoboSet/drag_mug_from_left_to_right_data": [
154
+ "Drag mug left to right."
155
+ ],
156
+ "set_6_planar_push_120": [
157
+ "Push the object from left to right."
158
+ ],
159
+ "mnt/raid5/data/roboset/v0.4/clean_kitchen/scene_3/clean_kitchen_slide_close_drawer_scene_3": [
160
+ "Slide and close the drawer."
161
+ ],
162
+ "set_4_med_block_7": [
163
+ "Pick up the wooden block."
164
+ ],
165
+ "mnt/raid5/data/roboset/v0.3/pick_banana_place_on_toaster_data": [
166
+ "Pick banana from table and place on toaster."
167
+ ],
168
+ "mnt/raid5/data/roboset/v0.3/pick_ketchup_from_toaster_place_on_table_data": [
169
+ "Pick ketchup from toaster and place it on table."
170
+ ],
171
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_place_butter_scene_1": [
172
+ "Place the butter on the cutting board."
173
+ ],
174
+ "mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_slide_open_drawer_scene_4": [
175
+ "Slide and open the drawer."
176
+ ],
177
+ "home/jaydv/Documents/RoboSet/pick_banana_place_on_plate_data": [
178
+ "Pick banana from table and place on plate."
179
+ ],
180
+ "set_8_pick_bottle_10": [
181
+ "Pick up the bottle."
182
+ ],
183
+ "home/jaydv/Documents/RoboSet/pick_ketchup_place_in_toaster_data": [
184
+ "Pick ketchup from the table and place in toaster."
185
+ ]
186
+ }
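+ # path2json maps a recording-path prefix (as stored in the RoboSet metadata) to the language instruction assigned to every episode under that prefix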
187
+
188
+ image_shape = (240, 424, 3)
189
+ Dmanus = ['']
190
+ def stash_image_into_observation(step):
191
+ step['observation'] = {'cam_high': [], 'cam_left_wrist': [], 'cam_right_wrist':[]}
192
+ step['observation']['cam_high'] = step['cam_high']
193
+ step['observation']['cam_left_wrist'] = step['cam_left_wrist']
194
+ step['observation']['cam_right_wrist'] = step['cam_right_wrist']
195
+ return step
196
+
197
+ def _parse_function(proto,instruction):
198
+ # Update the keys_to_features dictionary to match the new TFRecord format
199
+ keys_to_features = {
200
+ 'action': tf.io.FixedLenFeature([], tf.string),
201
+ 'action_gripper': tf.io.FixedLenFeature([], tf.string),
202
+ 'qpos': tf.io.FixedLenFeature([], tf.string),
203
+ 'qvel': tf.io.FixedLenFeature([], tf.string),
204
+ 'qpos_gripper': tf.io.FixedLenFeature([], tf.string),
205
+ 'qvel_gripper': tf.io.FixedLenFeature([], tf.string),
206
+ 'rgb_left': tf.io.FixedLenFeature([], tf.string),
207
+ 'rgb_right': tf.io.FixedLenFeature([], tf.string),
208
+ 'rgb_top': tf.io.FixedLenFeature([], tf.string),
209
+ 'terminate_episode': tf.io.FixedLenFeature([], tf.int64)
210
+ }
211
+
212
+ # Parse the incoming features according to the dictionary
213
+ parsed_features = tf.io.parse_single_example(proto, keys_to_features)
214
+
215
+ # Deserialize and reshape tensors as necessary
216
+ action = tf.io.parse_tensor(parsed_features['action'], out_type=tf.float16)
217
+ action_gripper = tf.io.parse_tensor(parsed_features['action_gripper'], out_type=tf.float16)
218
+ qpos = tf.io.parse_tensor(parsed_features['qpos'], out_type=tf.float16)
219
+ qvel = tf.io.parse_tensor(parsed_features['qvel'], out_type=tf.float16)
220
+ qpos_gripper = tf.io.parse_tensor(parsed_features['qpos_gripper'], out_type=tf.float16)
221
+ qvel_gripper = tf.io.parse_tensor(parsed_features['qvel_gripper'], out_type=tf.float16)
222
+ rgb_left = tf.io.parse_tensor(parsed_features['rgb_left'], out_type=tf.uint8)
223
+ rgb_right = tf.io.parse_tensor(parsed_features['rgb_right'], out_type=tf.uint8)
224
+ rgb_top = tf.io.parse_tensor(parsed_features['rgb_top'], out_type=tf.uint8)
225
+ terminate_episode = tf.cast(parsed_features['terminate_episode'], tf.int64)
226
+
227
+ # Reshape or modify other fields as needed to fit the model input
228
+ rgb_left = tf.reshape(rgb_left, image_shape)
229
+ rgb_right = tf.reshape(rgb_right, image_shape)
230
+ rgb_top = tf.reshape(rgb_top, image_shape)
231
+
232
+ return {
233
+ "action": action,
234
+ "action_gripper": action_gripper,
235
+ "qpos": qpos,
236
+ "qvel": qvel,
237
+ "qpos_gripper": qpos_gripper,
238
+ "qvel_gripper": qvel_gripper,
239
+ "observation": {
240
+ "rgb_left": rgb_left,
241
+ "rgb_right": rgb_right,
242
+ "rgb_top": rgb_top
243
+ },
244
+ "terminate_episode": terminate_episode,
245
+ "instruction": instruction
246
+ }
247
+
248
+
249
+ def dataset_generator_from_tfrecords(seed):
250
+ tfrecord_path = './data/datasets/roboset/tfrecords/'
251
+ failure = [f'set_{i}' for i in range(10, 18)]
252
+ filepaths = []
253
+ for root, dirs, files in os.walk(tfrecord_path):
254
+ # skip datasets with failure
255
+ fail = False
256
+ for f in failure:
257
+ if f in root:
258
+ fail = True
259
+ break
260
+ if fail:
261
+ continue
262
+
263
+ for filename in fnmatch.filter(files, '*.tfrecord'):
264
+ filepath = os.path.join(root, filename)
265
+ filepaths.append(filepath)
266
+
267
+ random.seed(seed)
268
+ random.shuffle(filepaths)
269
+ for filepath in filepaths:
270
+ for path in path2json:
271
+ if path in filepath:
272
+ instruction = path2json[path]
273
+ raw_dataset = tf.data.TFRecordDataset(filepath)
274
+ dataset = raw_dataset.map(lambda x: _parse_function(x,instruction))
275
+ yield {
276
+ 'steps': dataset
277
+ }
278
+
279
+ def load_dataset(seed):
280
+ dataset = tf.data.Dataset.from_generator(
281
+ lambda: dataset_generator_from_tfrecords(seed),
282
+ output_signature={
283
+ 'steps': tf.data.DatasetSpec(
284
+ element_spec={
285
+ 'action': tf.TensorSpec(shape=(None), dtype=tf.float16),
286
+ 'action_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
287
+ 'qpos': tf.TensorSpec(shape=(None), dtype=tf.float16),
288
+ 'qvel': tf.TensorSpec(shape=(None), dtype=tf.float16),
289
+ 'qpos_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
290
+ 'qvel_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
291
+ 'observation': {
292
+ 'rgb_left': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
293
+ 'rgb_right': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
294
+ 'rgb_top': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
295
+ },
296
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.int64),
297
+ 'instruction': tf.TensorSpec(shape=(None), dtype=tf.string)
298
+ }
299
+ )
300
+ }
301
+ )
302
+
303
+ return dataset
304
+
305
+
306
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
307
+ """
308
+ Convert terminate action to a boolean, where True means terminate.
309
+ """
310
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float16)),tf.constant(False),tf.constant(True))
311
+
312
+
313
+ def process_step(step: dict) -> dict:
314
+ """
315
+ Unify the action format and clean the task instruction.
316
+
317
+ DO NOT use python list, use tf.TensorArray instead.
318
+ """
319
+ # Convert raw action to our action
320
+ step['action'] = {}
321
+ step['action']['terminate'] = step['terminate_episode']
322
+ # undetermined action
323
+
324
+ state = step['observation']
325
+ qpos = tf.cast(step['qpos'], tf.float32)
326
+ # qvel = tf.cast(step['qvel'], tf.float32)
327
+ gripper_pos = tf.expand_dims(tf.cast(step['qpos_gripper'], tf.float32), axis=0)
328
+ # delete due to all zeros
329
+ # gripper_vel = tf.expand_dims(tf.cast(step['qvel_gripper'], tf.float32), axis=0)
330
+
331
+ # state['arm_concat'] = tf.concat([qpos, qvel, gripper_pos, gripper_vel], axis=0)
332
+ state['arm_concat'] = tf.concat([qpos, gripper_pos], axis=0)
333
+ # state['format'] = tf.constant(
334
+ # "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_0_vel"
335
+ # )
336
+ state['format'] = tf.constant(
337
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos"
338
+ )
339
+ # Clean the task instruction
340
+ # Define the replacements (old, new) as a dictionary
341
+ replacements = {
342
+ '_': ' ',
343
+ '1f': ' ',
344
+ '4f': ' ',
345
+ '-': ' ',
346
+ '50': ' ',
347
+ '55': ' ',
348
+ '56': ' ',
349
+
350
+ }
351
+ instr = step['instruction'][0]
352
+ instr = clean_task_instruction(instr, replacements)
353
+ step['observation']['natural_language_instruction'] = instr
354
+
355
+ return step
356
+
357
+
358
+ if __name__ == "__main__":
359
+ import tensorflow_datasets as tfds
360
+ from data.utils import dataset_to_path
361
+
362
+ dataset = load_dataset(seed=0)  # load_dataset requires a shuffle seed
363
+ for step in dataset.take(100):
364
+ for data in step['steps']:
365
+ data = process_step(data)
366
+ print(data)
367
+ break
data/preprocess_scripts/roboturk.py ADDED
@@ -0,0 +1,77 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, quaternion_to_euler, euler_to_quaternion
4
+
5
+
6
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
7
+ """
8
+ Convert terminate action to a boolean, where True means terminate.
9
+ """
10
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
11
+
12
+
13
+ def process_step(step: dict) -> dict:
14
+ """
15
+ Unify the action format and clean the task instruction.
16
+
17
+ DO NOT use python list, use tf.TensorArray instead.
18
+ """
19
+ # Convert raw action to our action
20
+ action = step['action']
21
+ action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
22
+
23
+ eef_delta_pos = action['world_vector']
24
+ eef_ang = action['rotation_delta']
25
+ eef_ang = euler_to_quaternion(eef_ang)
26
+
27
+ grip_open = tf.where(action['gripper_closedness_action']<0,tf.constant(1.0),tf.constant(0.0))
28
+
29
+ # No base found
30
+
31
+ # Concatenate the action
32
+ arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
33
+ action['arm_concat'] = arm_action
34
+
35
+ # Write the action format
36
+ action['format'] = tf.constant(
37
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
38
+
39
+ # No state found
40
+
41
+ # Clean the task instruction
42
+ # Define the replacements (old, new) as a dictionary
43
+ replacements = {
44
+ '_': ' ',
45
+ '1f': ' ',
46
+ '4f': ' ',
47
+ '-': ' ',
48
+ '50': ' ',
49
+ '55': ' ',
50
+ '56': ' ',
51
+
52
+ }
53
+ instr = step['observation']['natural_language_instruction']
54
+ instr = clean_task_instruction(instr, replacements)
55
+ step['observation']['natural_language_instruction'] = instr
56
+
57
+ return step
58
+
59
+
60
+ if __name__ == "__main__":
61
+ import tensorflow_datasets as tfds
62
+ from data.utils import dataset_to_path
63
+
64
+ DATASET_DIR = 'data/datasets/openx_embod'
65
+ DATASET_NAME = 'roboturk'
66
+ # Load the dataset
67
+ dataset = tfds.builder_from_directory(
68
+ builder_dir=dataset_to_path(
69
+ DATASET_NAME, DATASET_DIR))
70
+ dataset = dataset.as_dataset(split='all').take(1)
71
+
72
+ # Inspect the dataset
73
+ ze=tf.constant(0.0)
74
+ for episode in dataset:
75
+ for step in episode['steps']:
76
+ print(step)
77
+ break
data/preprocess_scripts/roboturk_real_objectsearch.py ADDED
@@ -0,0 +1,217 @@
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+ import tensorflow as tf
5
+ import h5py
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import os
9
+ import imageio
10
+ import concurrent.futures
11
+
12
+ def get_frames(file_path):
13
+ if not os.path.exists(file_path) or not os.path.isfile(file_path) or not file_path.endswith('.mp4'):
14
+ return []
15
+ frames = []
16
+ with imageio.get_reader(file_path, 'ffmpeg') as reader:
17
+ for frame in reader:
18
+ frame = np.array(frame, dtype=np.uint8)
19
+ frames.append(frame)
20
+ return frames
21
+
22
+ def parallel_get_frames(paths):
23
+ with concurrent.futures.ThreadPoolExecutor() as executor:
24
+ future_to_path = {executor.submit(get_frames, path): path for path in paths}
25
+ return [future.result() for future in concurrent.futures.as_completed(future_to_path)]
26
+
27
+ def count_total_samples(filename):
28
+ total_samples = 0
29
+ with h5py.File(filename, 'r') as f:
30
+ data = f['data']
31
+ for user_key in data.keys():
32
+ user = data[user_key]
33
+ for demo_key in user.keys():
34
+ total_samples += 1
35
+ return total_samples
36
+
37
+ def dataset_generator(filename, total_samples):
38
+ with h5py.File(filename, 'r') as f:
39
+ data = f['data']
40
+ for user_key in data.keys():
41
+ user = data[user_key]
42
+ for demo_key in user.keys():
43
+ demo = user[demo_key]
44
+ robot_observation = demo['robot_observation']
45
+ user_control = demo['user_control']
46
+
47
+ eef_poses = robot_observation['eef_poses']
48
+ joint_states_arm = robot_observation['joint_states_arm']
49
+ joint_states_gripper = robot_observation['joint_states_gripper']
50
+ user_control_data = user_control['user_control']
51
+
52
+ attrs = dict(demo.attrs)
53
+ top_depth_video_file = attrs['top_depth_video_file']
54
+ top_rgb_video_file = attrs['top_rgb_video_file']
55
+ front_rgb_video_file = attrs['front_rgb_video_file']
56
+
57
+ video_root_path = './data/datasets/roboturk/'
58
+ top_depth_frames = get_frames(os.path.join(video_root_path, top_depth_video_file))
59
+ top_rgb_frames = get_frames(os.path.join(video_root_path, top_rgb_video_file))
60
+ front_rgb_frames = get_frames(os.path.join(video_root_path, front_rgb_video_file))
61
+
62
+ if len(top_rgb_frames) == 0 or len(front_rgb_frames) == 0:
63
+ continue
64
+
65
+ steps = []
66
+ for i in range(len(eef_poses)):
67
+ task_demo_id = f"SawyerTowerCreation_{demo_key}_{i}"
68
+ step = {
69
+ 'task_demo_id': task_demo_id,
70
+ 'eef_poses': eef_poses[i],
71
+ 'joint_states_arm': joint_states_arm[i],
72
+ 'joint_states_gripper': joint_states_gripper[i],
73
+ 'user_control': user_control_data[i] if user_control_data.shape[0] > 0 else np.zeros(22),
74
+ 'observation':{
75
+ 'top_depth_frame': top_depth_frames[i] if i < len(top_depth_frames) else np.zeros((0,0, 3), dtype=np.uint8),
76
+ 'top_rgb_frame': top_rgb_frames[i] if i < len(top_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
77
+ 'front_rgb_frame': front_rgb_frames[i] if i < len(front_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
78
+ },
79
+ 'terminate_episode': i == len(eef_poses) - 1
80
+ }
81
+ steps.append(step)
82
+
83
+
84
+ steps_dataset = tf.data.Dataset.from_generator(
85
+ lambda: iter(steps),
86
+ output_signature={
87
+ 'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
88
+ 'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
89
+ 'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
90
+ 'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
91
+ 'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
92
+ 'observation':{
93
+ 'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
94
+ 'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
95
+ 'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
96
+ },
97
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.bool),
98
+ }
99
+ )
100
+
101
+ yield {'steps': steps_dataset}
102
+
103
+ def load_dataset():
104
+ filename = './data/datasets/roboturk/SawyerObjectSearch_aligned_dataset.hdf5'
105
+ total_samples = count_total_samples(filename)
106
+ dataset = tf.data.Dataset.from_generator(
107
+ lambda: dataset_generator(filename, total_samples),
108
+ output_signature={
109
+ 'steps': tf.data.DatasetSpec(
110
+ element_spec={
111
+ 'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
112
+ 'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
113
+ 'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
114
+ 'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
115
+ 'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
116
+ 'observation':{
117
+ 'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
118
+ 'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
119
+ 'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
120
+ },
121
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype = tf.bool),
122
+ }
123
+ )
124
+ }
125
+ )
126
+ return dataset
127
+
128
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
129
+ """
130
+ Convert terminate action to a boolean, where True means terminate.
131
+ """
132
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
133
+
134
+
135
+ def process_step(step: dict) -> dict:
136
+ """
137
+ Unify the action format and clean the task instruction.
138
+
139
+ DO NOT use python list, use tf.TensorArray instead.
140
+ """
141
+ # Convert raw action to our action
142
+ step['action'] = {}
143
+ action = step['action']
144
+ action['terminate'] = step['terminate_episode']
145
+
146
+ eef_delta_pos = step['eef_poses'][:3]
147
+ eef_ang = step['eef_poses'][3:]
148
+
149
+ # No base found
150
+ # Concatenate the action
151
+ arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
152
+ action['arm_concat'] = arm_action
153
+
154
+ # Write the action format
155
+ action['format'] = tf.constant(
156
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
157
+
158
+ # No state found
159
+ state = step['observation']
160
+ # joint_states_arm: dataset of (num_timestamps, 27) shape where each of the 9 joints is represented by the JointState message
161
+ # (the nine joints are in order by their ROSBAG names: ['head_pan', 'right_j0', 'right_j1', 'right_j2', 'right_j3', 'right_j4', 'right_j5', 'right_j6', 'torso_t0']. For the most part, head_pan and torso should be zeros)
162
+ # [0] the position of the first joint (rad or m)
163
+ # [1] the velocity of the first joint (rad/s or m/s)
164
+ # [2] the effort that is applied in the first joint
165
+ # [3] the position of the second joint...
166
+ joint_states_arm = step['joint_states_arm']
167
+ joint_pos = joint_states_arm[3:24:3]
168
+ joint_vel = joint_states_arm[4:25:3]
169
+ # joint_states_gripper: dataset of (num_timestamps, 3) shape
170
+ # [0] the position of the gripper (rad or m)
171
+ # [1] the velocity of the gripper (rad/s or m/s)
172
+ # [2] the effort that is applied in the gripper
173
+ joint_states_gripper = step['joint_states_gripper']
174
+ gripper_pos = joint_states_gripper[:1]
175
+ # remove gripper_vel due to they are all zeros
176
+ # gripper_vel = joint_states_gripper[1:2]
177
+ # Concatenate the state
178
+ # state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos,gripper_vel], axis=0)
179
+ state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos], axis=0)
180
+ # Write the state format
181
+ state['format'] = tf.constant(
182
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos")
183
+
184
+
185
+ # Clean the task instruction
186
+ # Define the replacements (old, new) as a dictionary
187
+ replacements = {
188
+ '_': ' ',
189
+ '1f': ' ',
190
+ '4f': ' ',
191
+ '-': ' ',
192
+ '50': ' ',
193
+ '55': ' ',
194
+ '56': ' ',
195
+
196
+ }
197
+ # instruction inferred from the SawyerObjectSearch dataset name (the original string was copied from the tower-creation script)
198
+ instr = b'search for the object'
199
+ instr = clean_task_instruction(instr, replacements)
200
+ step['observation']['natural_language_instruction'] = instr
201
+
202
+ return step
203
+
204
+
205
+ if __name__ == "__main__":
206
+ import tensorflow_datasets as tfds
207
+ from data.utils import dataset_to_path
208
+
209
+ DATASET_DIR = '/cephfs-thu/gsm_data/openx_embod'
210
+ DATASET_NAME = 'roboturk_real_objectsearch'
211
+ # Load the dataset
212
+ dataset = load_dataset()
213
+
214
+ # save_dir = os.path.join(DATASET_DIR, DATASET_NAME)
215
+ # if not os.path.exists(save_dir):
216
+ # os.makedirs(save_dir)
217
+ # tf.data.experimental.save(dataset, save_dir)
data/preprocess_scripts/roboturk_real_towercreation.py ADDED
@@ -0,0 +1,223 @@
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ from data.utils import clean_task_instruction, quaternion_to_euler
4
+ import tensorflow as tf
5
+ import h5py
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ import os
9
+ import imageio
10
+ import concurrent.futures
11
+
12
+ def get_frames(file_path):
13
+ if not os.path.exists(file_path) or not os.path.isfile(file_path) or not file_path.endswith('.mp4'):
14
+ return []
15
+ frames = []
16
+ with imageio.get_reader(file_path, 'ffmpeg') as reader:
17
+ for frame in reader:
18
+ frame = np.array(frame, dtype=np.uint8)
19
+ frames.append(frame)
20
+ return frames
21
+
22
+ def parallel_get_frames(paths):
23
+ with concurrent.futures.ThreadPoolExecutor() as executor:
24
+ future_to_path = {executor.submit(get_frames, path): path for path in paths}
25
+ return [future.result() for future in concurrent.futures.as_completed(future_to_path)]
26
+
27
+ def count_total_samples(filename):
28
+ total_samples = 0
29
+ with h5py.File(filename, 'r') as f:
30
+ data = f['data']
31
+ for user_key in data.keys():
32
+ user = data[user_key]
33
+ for demo_key in user.keys():
34
+ total_samples += 1
35
+ return total_samples
36
+
37
+ def dataset_generator(filename, total_samples):
38
+ with h5py.File(filename, 'r') as f:
39
+ data = f['data']
40
+ for user_key in data.keys():
41
+ user = data[user_key]
42
+ for demo_key in user.keys():
43
+ demo = user[demo_key]
44
+ robot_observation = demo['robot_observation']
45
+ user_control = demo['user_control']
46
+
47
+ eef_poses = robot_observation['eef_poses']
48
+ joint_states_arm = robot_observation['joint_states_arm']
49
+ joint_states_gripper = robot_observation['joint_states_gripper']
50
+ user_control_data = user_control['user_control']
51
+
52
+ attrs = dict(demo.attrs)
53
+ top_depth_video_file = attrs['top_depth_video_file']
54
+ top_rgb_video_file = attrs['top_rgb_video_file']
55
+ front_rgb_video_file = attrs['front_rgb_video_file']
56
+
57
+ video_root_path = './data/datasets/roboturk/'
58
+ top_depth_frames = get_frames(os.path.join(video_root_path, top_depth_video_file))
59
+ top_rgb_frames = get_frames(os.path.join(video_root_path, top_rgb_video_file))
60
+ front_rgb_frames = get_frames(os.path.join(video_root_path, front_rgb_video_file))
61
+
62
+ if len(top_rgb_frames) == 0 or len(front_rgb_frames) == 0:
63
+ continue
64
+ # video_root_path = '/cephfs-thu/gsm_data/robotruck'
65
+ # video_paths = [
66
+ # os.path.join(video_root_path, attrs['top_depth_video_file']),
67
+ # os.path.join(video_root_path, attrs['top_rgb_video_file']),
68
+ # os.path.join(video_root_path, attrs['front_rgb_video_file'])
69
+ # ]
70
+ # top_depth_frames, top_rgb_frames, front_rgb_frames = parallel_get_frames(video_paths)
71
+
72
+ steps = []
73
+ for i in range(len(eef_poses)):
74
+ task_demo_id = f"SawyerTowerCreation_{demo_key}_{i}"
75
+ step = {
76
+ 'task_demo_id': task_demo_id,
77
+ 'eef_poses': eef_poses[i],
78
+ 'joint_states_arm': joint_states_arm[i],
79
+ 'joint_states_gripper': joint_states_gripper[i],
80
+ 'user_control': user_control_data[i] if user_control_data.shape[0] > 0 else np.zeros(22),
81
+ 'observation':{
82
+ 'top_depth_frame': top_depth_frames[i] if i < len(top_depth_frames) else np.zeros((0,0, 3), dtype=np.uint8),
83
+ 'top_rgb_frame': top_rgb_frames[i] if i < len(top_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
84
+ 'front_rgb_frame': front_rgb_frames[i] if i < len(front_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
85
+ },
86
+ 'terminate_episode': i == len(eef_poses) - 1
87
+ }
88
+ steps.append(step)
89
+
90
+
91
+ steps_dataset = tf.data.Dataset.from_generator(
92
+ lambda: iter(steps),
93
+ output_signature={
94
+ 'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
95
+ 'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
96
+ 'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
97
+ 'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
98
+ 'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
99
+ 'observation':{
100
+ 'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
101
+ 'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
102
+ 'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
103
+ },
104
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.bool),
105
+ }
106
+ )
107
+
108
+ yield {'steps': steps_dataset}
109
+
110
+ def load_dataset():
111
+ filename = './data/datasets/roboturk/SawyerTowerCreation_aligned_dataset.hdf5'
112
+ total_samples = count_total_samples(filename)
113
+ dataset = tf.data.Dataset.from_generator(
114
+ lambda: dataset_generator(filename, total_samples),
115
+ output_signature={
116
+ 'steps': tf.data.DatasetSpec(
117
+ element_spec={
118
+ 'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
119
+ 'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
120
+ 'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
121
+ 'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
122
+ 'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
123
+ 'observation':{
124
+ 'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
125
+ 'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
126
+ 'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
127
+ },
128
+ 'terminate_episode': tf.TensorSpec(shape=(), dtype = tf.bool),
129
+ }
130
+ )
131
+ }
132
+ )
133
+ return dataset
134
+
135
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
136
+ """
137
+ Convert terminate action to a boolean, where True means terminate.
138
+ """
139
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
140
+
141
+
142
+ def process_step(step: dict) -> dict:
143
+ """
144
+ Unify the action format and clean the task instruction.
145
+
146
+ DO NOT use python list, use tf.TensorArray instead.
147
+ """
148
+ # Convert raw action to our action
149
+ step['action'] = {}
150
+ action = step['action']
151
+ action['terminate'] = step['terminate_episode']
152
+
153
+ eef_delta_pos = step['eef_poses'][:3]
154
+ eef_ang = quaternion_to_euler(step['eef_poses'][3:])
155
+
156
+ # No base found
157
+ # Concatenate the action
158
+ arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
159
+ action['arm_concat'] = arm_action
160
+
161
+ # Write the action format
162
+ action['format'] = tf.constant(
163
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
164
+
165
+ # No state found
166
+ state = step['observation']
167
+ # joint_states_arm: dataset of (num_timestamps, 27) shape where each of the 9 joints is represented by the JointState message
168
+ # (the nine joints are in order by their ROSBAG names: ['head_pan', 'right_j0', 'right_j1', 'right_j2', 'right_j3', 'right_j4', 'right_j5', 'right_j6', 'torso_t0']. For the most part, head_pan and torso should be zeros)
169
+ # [0] the position of the first joint (rad or m)
170
+ # [1] the velocity of the first joint (rad/s or m/s)
171
+ # [2] the effort that is applied in the first joint
172
+ # [3] the position of the second joint...
173
+ joint_states_arm = step['joint_states_arm']
174
+ joint_pos = joint_states_arm[3:24:3]
175
+ joint_vel = joint_states_arm[4:25:3]
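+ # strided slices select right_j0..right_j6 (positions at indices 3..21 step 3, velocities at 4..22 step 3), dropping head_pan and torso_t0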
176
+ # joint_states_gripper: dataset of (num_timestamps, 3) shape
177
+ # [0] the position of the gripper (rad or m)
178
+ # [1] the velocity of the gripper (rad/s or m/s)
179
+ # [2] the effort that is applied in the gripper
180
+ joint_states_gripper = step['joint_states_gripper']
181
+ gripper_pos = joint_states_gripper[:1]
182
+ # remove gripper_vel due to they are all zeros
183
+ # gripper_vel = joint_states_gripper[1:2]
184
+ # Concatenate the state
185
+ # state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos,gripper_vel], axis=0)
186
+ state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos], axis=0)
187
+ # Write the state format
188
+ state['format'] = tf.constant(
189
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos")
190
+
191
+ # Clean the task instruction
192
+ # Define the replacements (old, new) as a dictionary
193
+ replacements = {
194
+ '_': ' ',
195
+ '1f': ' ',
196
+ '4f': ' ',
197
+ '-': ' ',
198
+ '50': ' ',
199
+ '55': ' ',
200
+ '56': ' ',
201
+
202
+ }
203
+ # copied from openxembod
204
+ instr = b'create tower'
205
+ instr = clean_task_instruction(instr, replacements)
206
+ step['observation']['natural_language_instruction'] = instr
207
+
208
+ return step
209
+
210
+
211
+ if __name__ == "__main__":
212
+ import tensorflow_datasets as tfds
213
+ from data.utils import dataset_to_path
214
+
215
+ DATASET_DIR = '/cephfs-thu/gsm_data/openx_embod'
216
+ DATASET_NAME = 'roboturk_real_towercreation'
217
+ # Load the dataset
218
+ dataset = load_dataset()
219
+
220
+ # save_dir = os.path.join(DATASET_DIR, DATASET_NAME)
221
+ # if not os.path.exists(save_dir):
222
+ # os.makedirs(save_dir)
223
+ # tf.data.experimental.save(dataset, save_dir)
data/preprocess_scripts/stanford_hydra_dataset_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,94 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix,\
4
+ rotation_matrix_to_ortho6d
5
+
6
+
7
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
8
+ """
9
+ Convert terminate action to a boolean, where True means terminate.
10
+ """
11
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
12
+
13
+
14
+ def process_step(step: dict) -> dict:
15
+ """
16
+ Unify the action format and clean the task instruction.
17
+
18
+ DO NOT use python list, use tf.TensorArray instead.
19
+ """
20
+ # Convert raw action to our action
21
+ action = step['action']
22
+ eef_delta_pos = action[:3]
23
+ eef_ang = action[3:6]
24
+ eef_ang = euler_to_quaternion(eef_ang)
25
+ grip_open = tf.expand_dims(1 - action[6], axis=0)
26
+
27
+ # Concatenate the action
28
+ # action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
29
+ step['action'] = {}
30
+ action = step['action']
31
+ action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
32
+ action['terminate'] = step['is_terminal']
33
+
34
+ # Write the action format
35
+ action['format'] = tf.constant(
36
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
37
+
38
+ # Convert raw state to our state
39
+ state = step['observation']
40
+ state_vec = state['state']
41
+ # Robot state, consists of [3x EEF position,4x EEF orientation in quaternion,3x EEF orientation in euler angle,7x robot joint angles, 7x robot joint velocities,3x gripper state.
42
+ arm_joint_pos = state_vec[10:17]
43
+ arm_joint_vel = state_vec[17:24]
44
+ eef_pos = state_vec[:3]
45
+ eef_ang = state_vec[7:10]
46
+ eef_ang = euler_to_rotation_matrix(eef_ang)
47
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
48
+ # Rescale gripper width to [0, 1]
49
+ grip_joint_pos = tf.concat([
50
+ state_vec[24:25] * 12.324, state_vec[25:27]
51
+ ], axis=0)
52
+
53
+ # Concatenate the state
54
+ state['arm_concat'] = tf.concat([arm_joint_pos, grip_joint_pos, arm_joint_vel, eef_pos, eef_ang], axis=0)
55
+
56
+ # Write the state format
57
+ state['format'] = tf.constant(
58
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_2_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
59
+
60
+ # Clean the task instruction
61
+ # Define the replacements (old, new) as a dictionary
62
+ replacements = {
63
+ '_': ' ',
64
+ '1f': ' ',
65
+ '4f': ' ',
66
+ '-': ' ',
67
+ '50': ' ',
68
+ '55': ' ',
69
+ '56': ' ',
70
+
71
+ }
72
+ instr = step['language_instruction']
73
+ instr = clean_task_instruction(instr, replacements)
74
+ step['observation']['natural_language_instruction'] = instr
75
+
76
+ return step
77
+
78
+
79
+ if __name__ == "__main__":
80
+ import tensorflow_datasets as tfds
81
+ from data.utils import dataset_to_path
82
+
83
+ DATASET_DIR = 'data/datasets/openx_embod'
84
+ DATASET_NAME = 'fractal20220817_data'
85
+ # Load the dataset
86
+ dataset = tfds.builder_from_directory(
87
+ builder_dir=dataset_to_path(
88
+ DATASET_NAME, DATASET_DIR))
89
+ dataset = dataset.as_dataset(split='all')
90
+
91
+ # Inspect the dataset
92
+ for episode in dataset:
93
+ for step in episode['steps']:
94
+ print(step)
data/preprocess_scripts/tokyo_u_lsmo_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,90 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
4
+ rotation_matrix_to_ortho6d
5
+
6
+
7
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
8
+ """
9
+ Convert terminate action to a boolean, where True means terminate.
10
+ """
11
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
12
+
13
+
14
+ def process_step(step: dict) -> dict:
15
+ """
16
+ Unify the action format and clean the task instruction.
17
+
18
+ DO NOT use python list, use tf.TensorArray instead.
19
+ """
20
+ # Convert raw action to our action
21
+ action = step['action']
22
+ # Robot action, consists of [3x endeffector position, 3x euler angles,1x gripper action].
23
+ eef_delta_pos = action[:3]
24
+ eef_ang = action[3:6]
25
+ eef_ang = euler_to_quaternion(eef_ang)
26
+ grip_open = tf.expand_dims(1 - action[6], axis=0)
27
+
28
+ # Concatenate the action
29
+ step['action'] = {}
30
+ action = step['action']
31
+ action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
32
+ action['terminate'] = step['is_terminal']
33
+
34
+ # Write the action format
35
+ action['format'] = tf.constant(
36
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
37
+
38
+ # Convert raw state to our state
39
+ state = step['observation']
40
+ state_vec = state['state']
41
+ # Robot state, consists of [3x endeffector position, 3x euler angles,6x robot joint angles, 1x gripper position].
42
+ eef_pos = state_vec[:3]
43
+ eef_ang = state_vec[3:6]
44
+ eef_ang = euler_to_rotation_matrix(eef_ang)
45
+ eef_ang = rotation_matrix_to_ortho6d(eef_ang)
46
+ arm_joint_ang = state_vec[6:12]
47
+ grip_joint_pos = 1 - state_vec[12:13]
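+ # invert so that 1.0 corresponds to an open gripper, matching the gripper_open action convention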
48
+
49
+ # Concatenate the state
50
+ state['arm_concat'] = tf.concat([arm_joint_ang, grip_joint_pos, eef_pos, eef_ang], axis=0)
51
+
52
+ # Write the state format
53
+ state['format'] = tf.constant(
54
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
55
+
56
+ # Clean the task instruction
57
+ # Define the replacements (old, new) as a dictionary
58
+ replacements = {
59
+ '_': ' ',
60
+ '1f': ' ',
61
+ '4f': ' ',
62
+ '-': ' ',
63
+ '50': ' ',
64
+ '55': ' ',
65
+ '56': ' ',
66
+
67
+ }
68
+ instr = step['language_instruction']
69
+ instr = clean_task_instruction(instr, replacements)
70
+ step['observation']['natural_language_instruction'] = instr
71
+
72
+ return step
73
+
74
+
75
+ if __name__ == "__main__":
76
+ import tensorflow_datasets as tfds
77
+ from data.utils import dataset_to_path
78
+
79
+ DATASET_DIR = 'data/datasets/openx_embod'
80
+ DATASET_NAME = 'tokyo_u_lsmo_converted_externally_to_rlds'
81
+ # Load the dataset
82
+ dataset = tfds.builder_from_directory(
83
+ builder_dir=dataset_to_path(
84
+ DATASET_NAME, DATASET_DIR))
85
+ dataset = dataset.as_dataset(split='all')
86
+
87
+ # Inspect the dataset
88
+ for episode in dataset:
89
+ for step in episode['steps']:
90
+ print(step)
data/preprocess_scripts/utokyo_pr2_opening_fridge_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,92 @@
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
4
+ rotation_matrix_to_ortho6d
5
+
6
+
7
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
8
+ """
9
+ Convert terminate action to a boolean, where True means terminate.
10
+ """
11
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.float32)))
12
+
13
+
14
+ def process_step(step: dict) -> dict:
15
+ """
16
+ Unify the action format and clean the task instruction.
17
+
18
+ DO NOT use python list, use tf.TensorArray instead.
19
+ """
20
+
21
+ # Convert raw action to our action
22
+ action = step['action']
23
+ # Robot action, consists of [3x end effector pos, 3x robot rpy angles, 1x gripper open/close command, 1x terminal action].
24
+ eef_delta_pos = action[:3]/1000 # change from mm to m
25
+ eef_ang = action[3:6]
26
+ eef_ang = euler_to_quaternion(eef_ang)
27
+ grip_open = tf.expand_dims(1 - action[6], axis=0)
28
+
29
+ # Concatenate the action
30
+ step['action'] = {}
31
+ action = step['action']
32
+ action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
33
+
34
+ action['terminate'] = step['is_terminal']
35
+ # Write the action format
36
+ action['format'] = tf.constant(
37
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
38
+
39
+
40
+ # Convert raw state to our state
41
+ state = step['observation']['state']
42
+ # Robot state, consists of [3x end effector pos, 3x robot rpy angles, 1x gripper position].
43
+ gripper_pos = state[:3]/1000 # change from mm to m
44
+ gripper_ang = state[3:6]
45
+ gripper_ang = euler_to_rotation_matrix(gripper_ang)
46
+ gripper_ang = rotation_matrix_to_ortho6d(gripper_ang)
47
+ gripper_open = state[6:7]/1000 * 11.54 # rescale to [0, 1]
48
+
49
+
50
+ # Concatenate the state
51
+ state = step['observation']
52
+ state['arm_concat'] = tf.concat([gripper_pos, gripper_ang, gripper_open], axis=0)
53
+
54
+ # Write the state format
55
+ state['format'] = tf.constant(
56
+ "eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_joint_0_pos")
57
+
58
+ # Clean the task instruction
59
+ # Define the replacements (old, new) as a dictionary
60
+ replacements = {
61
+ '_': ' ',
62
+ '1f': ' ',
63
+ '4f': ' ',
64
+ '-': ' ',
65
+ '50': ' ',
66
+ '55': ' ',
67
+ '56': ' ',
68
+
69
+ }
70
+ instr = step['language_instruction']
71
+ instr = clean_task_instruction(instr, replacements)
72
+ step['observation']['natural_language_instruction'] = instr
73
+
74
+ return step
75
+
76
+
77
+ if __name__ == "__main__":
78
+ import tensorflow_datasets as tfds
79
+ from data.utils import dataset_to_path
80
+
81
+ DATASET_DIR = 'data/datasets/openx_embod'
82
+ DATASET_NAME = 'fractal20220817_data'
83
+ # Load the dataset
84
+ dataset = tfds.builder_from_directory(
85
+ builder_dir=dataset_to_path(
86
+ DATASET_NAME, DATASET_DIR))
87
+ dataset = dataset.as_dataset(split='all')
88
+
89
+ # Inspect the dataset
90
+ for episode in dataset:
91
+ for step in episode['steps']:
92
+ print(step)
data/preprocess_scripts/utokyo_xarm_bimanual_converted_externally_to_rlds.py ADDED
@@ -0,0 +1,117 @@
1
+ import json
2
+ import tensorflow as tf
3
+
4
+ from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
5
+
6
+
7
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
8
+ """
9
+ Convert terminate action to a boolean, where True means terminate.
10
+ """
11
+ return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
12
+
13
+
14
+ def process_step(step: dict) -> dict:
15
+ """
16
+ Unify the action format and clean the task instruction.
17
+
18
+ DO NOT use python list, use tf.TensorArray instead.
19
+ """
20
+ # Convert raw action to our action
21
+ action = step['action']
22
+
23
+ # TODO
24
+ # action : Tensor (14,) [3x EEF position (L), 3x EEF orientation yaw/pitch/roll (L), 1x gripper open/close position (L), 3x EEF position (R), 3x EEF orientation yaw/pitch/roll (R), 1x gripper open/close position (R)].
25
+
26
+ eef_pos_left = action[0:3]
27
+ eef_angle_left = tf.gather(action[3:6], [2, 1, 0])
28
+ eef_angle_left = euler_to_rotation_matrix(eef_angle_left)
29
+ eef_angle_left = rotation_matrix_to_ortho6d(eef_angle_left)
30
+ gripper_open_left = 1 - action[6:7]
31
+ eef_pos_right = action[7:10]
32
+ eef_angle_right = tf.gather(action[10:13], [2, 1, 0])
33
+ eef_angle_right = euler_to_rotation_matrix(eef_angle_right)
34
+ eef_angle_right = rotation_matrix_to_ortho6d(eef_angle_right)
35
+ gripper_open_right = 1 - action[13:14]
36
+
37
+ # Concatenate the action
38
+ step['action'] = {}
39
+ action = step['action']
40
+
41
+ # Concatenate the action
42
+ arm_action = tf.concat([eef_pos_left,eef_angle_left,gripper_open_left,eef_pos_right,eef_angle_right,gripper_open_right], axis=0)
43
+ action['arm_concat'] = arm_action
44
+ action['terminate'] = step['is_terminal']
45
+
46
+ # print("action len:", len(action['arm_concat']) + len(action['base_concat']))
47
+
48
+ action['format'] = tf.constant(
49
+ "left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open")
50
+
51
+ # action good for kuka same as example
52
+
53
+ # Convert raw state to our state
54
+ action = step['observation']['action_l']
55
+ # [3x EEF position, 3x EEF orientation yaw/pitch/roll, 1x gripper open/close position].
56
+ eef_pos_left = action[0:3]
57
+ eef_angle_left = tf.gather(action[3:6], [2, 1, 0])
58
+ eef_angle_left = euler_to_rotation_matrix(eef_angle_left)
59
+ eef_angle_left = rotation_matrix_to_ortho6d(eef_angle_left)
60
+ gripper_open_left = 1 - action[6:7]
61
+
62
+ action = step['observation']['action_r']
63
+ eef_pos_right = action[0:3]
64
+ eef_angle_right = tf.gather(action[3:6], [2, 1, 0])
65
+ eef_angle_right = euler_to_rotation_matrix(eef_angle_right)
66
+ eef_angle_right = rotation_matrix_to_ortho6d(eef_angle_right)
67
+ gripper_open_right = 1 - action[6:7]
68
+
69
+ # Write the state format TODO how to link 12 joint pos to 7 joint pos ??
70
+ state = step['observation']
71
+ # Concatenate the state
72
+ state['arm_concat'] = tf.concat([eef_pos_left,eef_angle_left,gripper_open_left,eef_pos_right,eef_angle_right,gripper_open_right], axis=0)
73
+ state['format'] = tf.constant(
74
+ "left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open")
75
+
76
+ # Clean the task instruction
77
+ # Define the replacements (old, new) as a dictionary
78
+ replacements = {
79
+ '_': ' ',
80
+ '1f': ' ',
81
+ '4f': ' ',
82
+ '-': ' ',
83
+ '50': ' ',
84
+ '55': ' ',
85
+ '56': ' ',
86
+
87
+ }
88
+ instr = step['language_instruction']
89
+ instr = clean_task_instruction(instr, replacements)
90
+ step['observation']['natural_language_instruction'] = instr
91
+
92
+ return step
93
+
94
+
95
+ if __name__ == "__main__":
96
+ import tensorflow_datasets as tfds
97
+ from data.utils import dataset_to_path
98
+
99
+ DATASET_DIR = 'data/datasets/openx_embod'
100
+ DATASET_NAME = 'utokyo_xarm_bimanual_converted_externally_to_rlds'
101
+ # Load the dataset
102
+ dataset = tfds.builder_from_directory(
103
+ builder_dir=dataset_to_path(
104
+ DATASET_NAME, DATASET_DIR))
105
+ dataset = dataset.as_dataset(split='all')
106
+
107
+ # with open('example.txt', 'w') as file:
108
+ # Inspect the dataset
109
+
110
+ episode_num = len(dataset)
111
+ print(f"episode_num: {episode_num}")
112
+ for episode in dataset.take(1):
113
+ # print("episode")
114
+ # print(list(episode.keys()))
115
+ for step in episode['steps']:
116
+ process_step(step)
117
+ break
data/preprocess_scripts/viola.py ADDED
1
+ import tensorflow as tf
2
+
3
+ from data.utils import clean_task_instruction, euler_to_quaternion, rotation_matrix_to_ortho6d
4
+
5
+
6
+ def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
7
+ """
8
+ Convert terminate action to a boolean, where True means terminate.
9
+ """
10
+ return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
11
+
12
+
13
+ def process_step(step: dict) -> dict:
14
+ """
15
+ Unify the action format and clean the task instruction.
16
+
17
+ DO NOT use python list, use tf.TensorArray instead.
18
+ """
19
+ # Convert raw action to our action
20
+ action = step['action']
21
+ action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
22
+
23
+ eef_delta_pos = action['world_vector']
24
+ eef_ang = action['rotation_delta']
25
+ eef_ang = euler_to_quaternion(eef_ang)
26
+ grip_open = tf.reshape(tf.where(action['gripper_closedness_action']<0,tf.constant(1.0),tf.constant(0.0)),(1,))
27
+
28
+ # No base found
29
+
30
+ # Concatenate the action
31
+ arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
32
+ action['arm_concat'] = arm_action
33
+
34
+ # Write the action format
35
+ action['format'] = tf.constant(
36
+ "eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
37
+
38
+ # Convert raw state to our state
39
+ state = step['observation']
40
+ joint_pos=state['joint_states']
41
+ grip_open=state['gripper_states'] * 12.905 # rescale to [0, 1]
42
+ state_ee=state['ee_states']
43
+ transform_matrix = tf.transpose(tf.reshape(state_ee, (4, 4)))
44
+ eef_pos = transform_matrix[:3, 3]
45
+ rotation_matrix = transform_matrix[:3, :3]
46
+ eef_ang = rotation_matrix_to_ortho6d(rotation_matrix)
47
+
48
+ # Concatenate the state
49
+ state['arm_concat'] = tf.concat([joint_pos,grip_open,eef_pos,eef_ang],axis=0)
50
+
51
+ # Write the state format
52
+ state['format'] = tf.constant(
53
+ "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
54
+
55
+ # Clean the task instruction
56
+ # Define the replacements (old, new) as a dictionary
57
+ replacements = {
58
+ '_': ' ',
59
+ '1f': ' ',
60
+ '4f': ' ',
61
+ '-': ' ',
62
+ '50': ' ',
63
+ '55': ' ',
64
+ '56': ' ',
65
+
66
+ }
67
+ instr = step['observation']['natural_language_instruction']
68
+ instr = clean_task_instruction(instr, replacements)
69
+ step['observation']['natural_language_instruction'] = instr
70
+
71
+ return step
72
+
73
+
74
+ if __name__ == "__main__":
75
+ import tensorflow_datasets as tfds
76
+ from data.utils import dataset_to_path
77
+
78
+ DATASET_DIR = 'data/datasets/openx_embod'
79
+ DATASET_NAME = 'viola'
80
+ # Load the dataset
81
+ dataset = tfds.builder_from_directory(
82
+ builder_dir=dataset_to_path(
83
+ DATASET_NAME, DATASET_DIR))
84
+ dataset = dataset.as_dataset(split='all')
85
+
86
+ # Inspect the dataset
87
+ for episode in dataset:
88
+ for step in episode['steps']:
89
+ print(step)
data/producer.py ADDED
@@ -0,0 +1,280 @@
1
+ import time
2
+ import json
3
+ import os
4
+ import time
5
+ import argparse
6
+ import sys
7
+ import signal
8
+ import random
9
+ from multiprocessing import Process
10
+
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ import yaml
14
+
15
+ from data.vla_dataset import VLADataset
16
+ from data.filelock import FileLock
17
+
18
+
19
+ # Producer does not need GPU
20
+ tf.config.set_visible_devices([], 'GPU')
21
+
22
+ # Read the config
23
+ with open('configs/base.yaml', 'r') as file:
24
+ config = yaml.safe_load(file)
25
+ # Load some constants from the config
26
+ BUF_PATH = config['dataset']['buf_path']
27
+ BUF_NUM_CHUNKS = config['dataset']['buf_num_chunks']
28
+ if BUF_NUM_CHUNKS < 1:
29
+ raise ValueError("Config `buf_num_chunks` must be at least 1.")
30
+ BUF_CHUNK_SIZE = config['dataset']['buf_chunk_size']
31
+ if BUF_CHUNK_SIZE < 1:
32
+ raise ValueError("Config `buf_chunk_size` must be at least 1.")
33
+
34
+
35
+ def get_dirty_item(chunk_dir):
36
+ """
37
+ Get indexes of dirty items in a chunk.
38
+ """
39
+ dirty_bit = read_dirty_bit(chunk_dir)
40
+ return np.where(dirty_bit)[0].tolist()
41
+
42
+
43
+ def get_clean_item(chunk_dir):
44
+ """
45
+ Get indexes of clean items in a chunk.
46
+ """
47
+ dirty_bit = read_dirty_bit(chunk_dir)
48
+ return np.where(1 - dirty_bit)[0].tolist()
49
+
50
+
51
+ def save_dirty_bit(chunk_dir, dirty_bit):
52
+ """
53
+ Save the dirty bit to the chunk directory.
54
+ """
55
+ time_stmp = time.time()
56
+ while time.time() - time_stmp < 10.0:
57
+ try:
58
+ file_path = os.path.join(chunk_dir, "dirty_bit")
59
+ lock = FileLock(file_path)
60
+ lock.acquire_write_lock()
61
+ with open(file_path, 'wb') as file:
62
+ file.write(dirty_bit.tobytes())
63
+ lock.release_lock()
64
+ return
65
+ except KeyboardInterrupt:
66
+ lock.release_lock()
67
+ raise KeyboardInterrupt
68
+ except BaseException:
69
+ lock.release_lock()
70
+ continue
71
+ # raise RuntimeError("Failed to save dirty bit.")
72
+ print("Failed to save dirty bit.")
73
+
74
+
75
+ def read_dirty_bit(chunk_dir):
76
+ """
77
+ Read the dirty bit from the chunk directory.
78
+ """
79
+ # If error occurs, retry
80
+ time_stmp = time.time()
81
+ while time.time() - time_stmp < 10.0:
82
+ try:
83
+ file_path = os.path.join(chunk_dir, "dirty_bit")
84
+ lock = FileLock(file_path)
85
+ lock.acquire_read_lock()
86
+ with open(file_path, 'rb') as file:
87
+ dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
88
+ lock.release_lock()
89
+ assert len(dirty_bit) == BUF_CHUNK_SIZE
90
+ return dirty_bit
91
+ except KeyboardInterrupt:
92
+ lock.release_lock()
93
+ raise KeyboardInterrupt
94
+ except BaseException:
95
+ lock.release_lock()
96
+ continue
97
+ # If failed to read the dirty bit, return all ones for robustness
98
+ return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
99
+
100
+
101
+ def save_sample(step_dict, chunk_dir, chunk_item_idx):
102
+ """
103
+ Save a sample to the chunk directory.
104
+ """
105
+ # Save the json content
106
+ time_stmp = time.time()
107
+ while time.time() - time_stmp < 10.0:
108
+ try:
109
+ locks = []
110
+ json_content = step_dict['json_content']
111
+ file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
112
+ lock = FileLock(file_path)
113
+ locks.append(lock)
114
+ lock.acquire_write_lock()
115
+ with open(file_path, 'w') as file:
116
+ json.dump(json_content, file, indent=4)
117
+ lock.release_lock()
118
+ # Save all other tensors in a npz
119
+ file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
120
+ lock = FileLock(file_path)
121
+ locks.append(lock)
122
+ lock.acquire_write_lock()
123
+ with open(file_path, 'wb') as file:
124
+ np.savez(
125
+ file,
126
+ step_id=step_dict['step_id'].numpy(),
127
+ state_chunk=step_dict['state_chunk'].numpy(),
128
+ state_chunk_time_mask=step_dict['state_chunk_time_mask'].numpy(),
129
+ action_chunk=step_dict['action_chunk'].numpy(),
130
+ action_chunk_time_mask=step_dict['action_chunk_time_mask'].numpy(),
131
+ state_vec_mask=step_dict['state_vec_mask'].numpy(),
132
+ past_frames_0=step_dict['past_frames_0'].numpy(),
133
+ past_frames_0_time_mask=step_dict['past_frames_0_time_mask'].numpy(),
134
+ past_frames_1=step_dict['past_frames_1'].numpy(),
135
+ past_frames_1_time_mask=step_dict['past_frames_1_time_mask'].numpy(),
136
+ past_frames_2=step_dict['past_frames_2'].numpy(),
137
+ past_frames_2_time_mask=step_dict['past_frames_2_time_mask'].numpy(),
138
+ past_frames_3=step_dict['past_frames_3'].numpy(),
139
+ past_frames_3_time_mask=step_dict['past_frames_3_time_mask'].numpy(),
140
+ state_std=step_dict['state_std'].numpy(),
141
+ state_mean=step_dict['state_mean'].numpy(),
142
+ state_norm=step_dict['state_norm'].numpy(),
143
+ )
144
+ lock.release_lock()
145
+ return
146
+ except KeyboardInterrupt:
147
+ for lock in locks:
148
+ lock.release_lock()
149
+ raise KeyboardInterrupt
150
+ except BaseException:
151
+ for lock in locks:
152
+ lock.release_lock()
153
+ continue
154
+ # raise RuntimeError("Failed to save sample.")
155
+ print("Failed to save sample.")
156
+
157
+
158
+ def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
159
+ """
160
+ Run the producer.
161
+ The producer will first fill up the buffer with samples.
162
+ Then it will keep replacing dirty samples
163
+ (i.e., samples that have been read by the consumer)
164
+ with new samples.
165
+ """
166
+ vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
167
+ chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
168
+ chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
169
+ if fill_up:
170
+ print(f"Worker {worker_id}: Start filling up the buffer...")
171
+ elif clean_dirty:
172
+ # Only refresh the dirty bits
173
+ print(f"Worker {worker_id}: Start refreshing the dirty bits...")
174
+ for chunk_idx in range(chunk_start_idx, chunk_end_idx):
175
+ chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
176
+ dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
177
+ save_dirty_bit(chunk_dir, dirty_bit)
178
+ print(f"Worker {worker_id}: Refreshed the dirty bits.")
179
+
180
+ fill_chunk_idx = chunk_start_idx
181
+ fill_chunk_item_idx = 0
182
+ dirty_chunk_idx = chunk_start_idx
183
+ dirty_chunk_item_idxs = []
184
+ time_stmp = time.time()
185
+ for episode_steps in vla_dataset:
186
+ for step in episode_steps:
187
+ if fill_up and fill_chunk_idx < chunk_end_idx:
188
+ # Fill up the buffer
189
+ chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
190
+ if fill_chunk_item_idx == 0:
191
+ # Create a new chunk
192
+ os.makedirs(chunk_dir, exist_ok=True)
193
+ # Write the dirty bit of size BUF_CHUNK_SIZE
194
+ dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
195
+ save_dirty_bit(chunk_dir, dirty_bit)
196
+
197
+ # Save the sample
198
+ save_sample(step, chunk_dir, fill_chunk_item_idx)
199
+
200
+ # print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}")
201
+ local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
202
+ local_num_chunks = chunk_end_idx - chunk_start_idx
203
+ if (local_fill_chunk_idx % 10 == 0 or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
204
+ print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
205
+ fill_chunk_item_idx += 1
206
+ if fill_chunk_item_idx == BUF_CHUNK_SIZE:
207
+ fill_chunk_idx += 1
208
+ fill_chunk_item_idx = 0
209
+ if fill_chunk_idx == BUF_NUM_CHUNKS:
210
+ print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")
211
+
212
+ else:
213
+ # Search for the dirty chunk to replace
214
+ while len(dirty_chunk_item_idxs) == 0:
215
+ dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
216
+ dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
217
+ # Print the dirty ratio
218
+ if time.time() - time_stmp > 2.0:
219
+ dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
220
+ print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
221
+ time_stmp = time.time()
222
+
223
+ if len(dirty_chunk_item_idxs) > 0:
224
+ # Lock the chunk
225
+ dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
226
+ save_dirty_bit(dirty_chunk_dir, dirty_bit)
227
+
228
+ # Iterate over the chunks
229
+ dirty_chunk_idx += 1
230
+ if dirty_chunk_idx == chunk_end_idx:
231
+ dirty_chunk_idx = chunk_start_idx
232
+
233
+ # Replace the dirty item
234
+ dirty_item_idx = dirty_chunk_item_idxs.pop()
235
+ chunk_dir = dirty_chunk_dir  # write into the chunk whose dirty bits were just locked
236
+ # Save the sample
237
+ save_sample(step, chunk_dir, dirty_item_idx)
238
+
239
+ # If we have replaced all dirty items in the chunk
240
+ if len(dirty_chunk_item_idxs) == 0:
241
+ # Unlock the chunk
242
+ dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
243
+ save_dirty_bit(dirty_chunk_dir, dirty_bit)
244
+ print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
245
+
246
+
247
+ if __name__ == '__main__':
248
+ # Args: n_workers, fill_up
249
+ parser = argparse.ArgumentParser()
250
+ parser.add_argument('--n_workers', type=int, default=2, help="Number of parallel workers. It should be less than or equal to the number of chunks.")
251
+ parser.add_argument('--fill_up', action='store_true', help="Whether to fill up the buffer before replacing dirty samples.")
252
+ parser.add_argument('--clean_dirty', action='store_true', help="Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.")
253
+ parser.add_argument('--seed', type=int, default=None, help="Random seed. If not set, the seed will be randomly generated.")
254
+ parser.add_argument('--dataset_type', type=str,
255
+ default="pretrain",
256
+ help="Whether to load the pretrain dataset or finetune dataset.")
257
+
258
+ # Run the producer
259
+ args = parser.parse_args()
260
+ if args.seed is not None:
261
+ print(f"Base seed: {args.seed}")
262
+ random.seed(args.seed)
263
+
264
+ processes = []
265
+ process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
266
+ print(f"Process seeds: {process_seeds}")
267
+ def signal_handler(sig, frame):
268
+ print("Ctrl+C received. Terminating child processes...")
269
+ for p in processes:
270
+ p.terminate()
271
+ sys.exit(0)
272
+ signal.signal(signal.SIGINT, signal_handler)
273
+ for worker_id in range(args.n_workers):
274
+ p = Process(target=run_producer, args=(
275
+ process_seeds[worker_id], args.n_workers, worker_id, args.fill_up, args.clean_dirty, args.dataset_type))
276
+ p.start()
277
+ processes.append(p)
278
+
279
+ for p in processes:
280
+ p.join()
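
For reference, a consumer-side read of one buffered item can be sketched with plain NumPy and JSON, mirroring the file layout written by `save_sample` above. This is only an illustrative sketch: the chunk path and item index are placeholders, and a real consumer should also go through `FileLock` as the producer does.

```python
import json
import os

import numpy as np

chunk_dir = "buffer/chunk_0"   # placeholder for a chunk directory under BUF_PATH (configs/base.yaml)
item_idx = 0                   # placeholder item index within the chunk

# Language/meta fields are stored as JSON next to the tensors
with open(os.path.join(chunk_dir, f"json_content_{item_idx}.json"), "r") as f:
    json_content = json.load(f)

# All tensors of the sample live in one .npz, with the keys written by save_sample
sample = np.load(os.path.join(chunk_dir, f"sample_{item_idx}.npz"))
print(sample["action_chunk"].shape, sample["state_chunk"].shape)

# Mark the item as dirty (consumed) so the producer will refill this slot
dirty_bit = np.fromfile(os.path.join(chunk_dir, "dirty_bit"), dtype=np.uint8)
dirty_bit[item_idx] = 1
dirty_bit.tofile(os.path.join(chunk_dir, "dirty_bit"))
```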
data/utils.py ADDED
@@ -0,0 +1,235 @@
1
+ import tensorflow as tf
2
+ import tensorflow_graphics.geometry.transformation.euler as tf_euler
3
+ import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
4
+ import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
5
+
6
+
7
+ def dataset_to_path(dataset_name: str, dir_name: str) -> str:
8
+ """
9
+ Return the path to the dataset.
10
+ """
11
+ if dataset_name == 'robo_net' or \
12
+ dataset_name == 'cmu_playing_with_food' or \
13
+ dataset_name == 'droid':
14
+ version = '1.0.0'
15
+ elif dataset_name == 'language_table' or \
16
+ dataset_name == 'fmb' or \
17
+ dataset_name == 'dobbe':
18
+ version = '0.0.1'
19
+ elif dataset_name == 'nyu_door_opening_surprising_effectiveness':
20
+ version = ''
21
+ elif dataset_name == 'cmu_play_fusion':
22
+ version=''
23
+ elif dataset_name=='berkeley_gnm_recon':
24
+ version=''
25
+ elif dataset_name=='vla_benchmark_ee':
26
+ version = '1.0.0'
27
+ else:
28
+ version = '1.0.0'
29
+ return f'{dir_name}/{dataset_name}/{version}'
30
+
31
+
32
+ def clean_task_instruction(
33
+ task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
34
+ """
35
+ Clean up the natural language task instruction.
36
+ """
37
+ # Create a function that applies all replacements
38
+ def apply_replacements(tensor):
39
+ for old, new in replacements.items():
40
+ tensor = tf.strings.regex_replace(tensor, old, new)
41
+ return tensor
42
+ # Apply the replacements and strip leading and trailing spaces
43
+ cleaned_task_instruction = apply_replacements(task_instruction)
44
+ cleaned_task_instruction = tf.strings.strip(cleaned_task_instruction)
45
+ return cleaned_task_instruction
46
+
47
+
48
+ def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
49
+ """
50
+ Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
51
+ The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
52
+ """
53
+ # Normalize the quaternion
54
+ quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
55
+ return tf_euler.from_quaternion(quaternion)
56
+
57
+
58
+ def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
59
+ """
60
+ Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
61
+ The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
62
+ """
63
+ quaternion = tf_quat.from_euler(euler)
64
+ return tf.nn.l2_normalize(quaternion, axis=-1)
65
+
66
+
67
+ def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
68
+ """
69
+ Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
70
+ The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
71
+ """
72
+ return tf_euler.from_rotation_matrix(matrix)
73
+
74
+
75
+ def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
76
+ """
77
+ Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
78
+ """
79
+ quaternion = tf_quat.from_rotation_matrix(matrix)
80
+ return tf.nn.l2_normalize(quaternion, axis=-1)
81
+
82
+
83
+ def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
84
+ """
85
+ Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
86
+ The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
87
+ """
88
+ return tf_rotmat.from_euler(euler)
89
+
90
+
91
+ def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
92
+ """
93
+ Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
94
+ """
95
+ # Normalize the quaternion
96
+ quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
97
+ return tf_rotmat.from_quaternion(quaternion)
98
+
99
+
100
+ def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
101
+ """
102
+ Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
103
+ This function is used to make tensorflow happy.
104
+ """
105
+ # Normalize the quaternion
106
+ quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
107
+
108
+ x = quaternion[..., 0]
109
+ y = quaternion[..., 1]
110
+ z = quaternion[..., 2]
111
+ w = quaternion[..., 3]
112
+
113
+ tx = 2.0 * x
114
+ ty = 2.0 * y
115
+ tz = 2.0 * z
116
+ twx = tx * w
117
+ twy = ty * w
118
+ twz = tz * w
119
+ txx = tx * x
120
+ txy = ty * x
121
+ txz = tz * x
122
+ tyy = ty * y
123
+ tyz = tz * y
124
+ tzz = tz * z
125
+ matrix = tf.stack((1.0 - (tyy + tzz), txy - twz, txz + twy,
126
+ txy + twz, 1.0 - (txx + tzz), tyz - twx,
127
+ txz - twy, tyz + twx, 1.0 - (txx + tyy)),
128
+ axis=-1) # pyformat: disable
129
+ output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
130
+ return tf.reshape(matrix, shape=output_shape)
131
+
132
+
133
+ """
134
+ Below is a continuous 6D rotation representation adapted from
135
+ On the Continuity of Rotation Representations in Neural Networks
136
+ https://arxiv.org/pdf/1812.07035.pdf
137
+ https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
138
+ """
139
+ def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
140
+ """
141
+ The ortho6d represents the first two column vectors a1 and a2 of the
142
+ rotation matrix: [ | , |, | ]
143
+ [ a1, a2, a3]
144
+ [ | , |, | ]
145
+ Input: (A1, ..., An, 3, 3)
146
+ Output: (A1, ..., An, 6)
147
+ """
148
+ ortho6d = matrix[..., :, :2]
149
+ # Transpose the last two dimension
150
+ perm = list(range(len(ortho6d.shape)))
151
+ perm[-2], perm[-1] = perm[-1], perm[-2]
152
+ ortho6d = tf.transpose(ortho6d, perm)
153
+ # Flatten the last two dimension
154
+ ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
155
+ return ortho6d
156
+
157
+
158
+ def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
159
+ """
160
+ The ortho6d represents the first two column vectors a1 and a2 of the
161
+ rotation matrix: [ | , |, | ]
162
+ [ a1, a2, a3]
163
+ [ | , |, | ]
164
+ Input: (3, 3)
165
+ Output: (6,)
166
+ This function is used to make tensorflow happy.
167
+ """
168
+ ortho6d = matrix[:, :2]
169
+ # Transpose the last two dimension
170
+ ortho6d = tf.transpose(ortho6d)
171
+ # Flatten the last two dimension
172
+ ortho6d = tf.reshape(ortho6d, [6])
173
+ return ortho6d
174
+
175
+
176
+ def normalize_vector(v):
177
+ """
178
+ v: (..., N)
179
+ """
180
+ v_mag = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
181
+ v_mag = tf.maximum(v_mag, 1e-8)
182
+ v_normalized = v / v_mag
183
+
184
+ return v_normalized
185
+
186
+
187
+ def cross_product(u, v):
188
+ """
189
+ u: (..., 3)
190
+ v: (..., 3)
191
+ u x v: (..., 3)
192
+ """
193
+ i = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1]
194
+ j = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2]
195
+ k = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]
196
+ out = tf.stack([i, j, k], axis=-1)
197
+ return out
198
+
199
+
200
+ def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
201
+ """
202
+ The ortho6d represents the first two column vectors a1 and a2 of the
203
+ rotation matrix: [ | , |, | ]
204
+ [ a1, a2, a3]
205
+ [ | , |, | ]
206
+ Input: (A1, ..., An, 6)
207
+ Output: (A1, ..., An, 3, 3)
208
+ """
209
+ x_raw = ortho6d[..., 0:3]
210
+ y_raw = ortho6d[..., 3:6]
211
+
212
+ x = normalize_vector(x_raw)
213
+ z = cross_product(x, y_raw)
214
+ z = normalize_vector(z)
215
+ y = cross_product(z, x)
216
+
217
+ # Stack x, y, z to form the matrix
218
+ matrix = tf.stack([x, y, z], axis=-1)
219
+ return matrix
220
+
221
+
222
+ def capitalize_and_period(instr: str) -> str:
223
+ """
224
+ Capitalize the first letter of a string and add a period to the end if it's not there.
225
+ """
226
+ if len(instr) > 0:
227
+ # if the first letter is not capital, make it so
228
+ if not instr[0].isupper():
229
+ # Uppercase the first character
230
+ instr = instr[0].upper() + instr[1:]
231
+ # add period to the end if it's not there
232
+ if instr[-1] != '.':
233
+ # Append a period
234
+ instr = instr + '.'
235
+ return instr
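
The rotation helpers above are designed to compose. A minimal round-trip sanity check, assuming `data/utils.py` is importable from the repository root and TensorFlow Graphics is installed, might look like this:

```python
import numpy as np
import tensorflow as tf

from data.utils import (euler_to_rotation_matrix, ortho6d_to_rotation_matrix,
                        rotation_matrix_to_ortho6d)

# A batch of two (roll, pitch, yaw) triples in radians
euler = tf.constant([[0.1, -0.3, 0.5], [1.2, 0.0, -0.7]], dtype=tf.float32)

rot = euler_to_rotation_matrix(euler)           # shape (2, 3, 3)
ortho6d = rotation_matrix_to_ortho6d(rot)       # shape (2, 6): the first two columns, flattened
rot_back = ortho6d_to_rotation_matrix(ortho6d)  # shape (2, 3, 3)

# For valid rotation matrices, the round trip recovers the input up to numerical error
print(np.max(np.abs(rot.numpy() - rot_back.numpy())))
```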
data/vla_dataset.py ADDED
@@ -0,0 +1,147 @@
1
+ import json
2
+ import random
3
+
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ import tensorflow_datasets as tfds
7
+ import yaml
8
+
9
+ from data.episode_transform import process_episode, flatten_episode, \
10
+ flatten_episode_agilex, bgr_to_rgb
11
+ from data.utils import dataset_to_path
12
+ from data.preprocess_scripts import *
13
+
14
+ # Producer does not need GPU
15
+ tf.config.set_visible_devices([], 'GPU')
16
+
17
+ OPENX_EMBOD_DIR = 'data/datasets/openx_embod'
18
+
19
+ DATASET_NAMES_NOOPENX = [
20
+ "aloha_mobile",
21
+ "aloha_static",
22
+ "roboset",
23
+ "agilex",
24
+ "rh20t",
25
+ 'calvin',
26
+ "bridgev2"
27
+ ]
28
+
29
+ # Read the config
30
+ with open('configs/base.yaml', 'r') as file:
31
+ config = yaml.safe_load(file)
32
+ # Load some constants from the config
33
+ EPSD_LEN_THRESH_LOW = config['dataset']['epsd_len_thresh_low']
34
+ EPSD_LEN_THRESH_HIGH = config['dataset']['epsd_len_thresh_high']
35
+ # Read the image keys of each dataset
36
+ with open('configs/dataset_img_keys.json', 'r') as file:
37
+ IMAGE_KEYS = json.load(file)
38
+
39
+
40
+ class VLADataset:
41
+ """
42
+ This class is used to sample episodes from the embodiment datasets.
43
+ """
44
+ def __init__(self, seed, dataset_type, repeat=True):
45
+ '''
46
+ seed: the random seed
47
+ dataset_type: 'pretrain' or 'finetune', which dataset to load
48
+ repeat: whether to repeat to infinite length
49
+ '''
50
+ dataset_names_cfg = 'configs/pretrain_datasets.json' \
51
+ if dataset_type == "pretrain" else 'configs/finetune_datasets.json'
52
+ with open(dataset_names_cfg, 'r') as file:
53
+ DATASET_NAMES = json.load(file)
54
+ self.dataset_names = DATASET_NAMES
55
+ sample_weights_cfg = 'configs/pretrain_sample_weights.json' \
56
+ if dataset_type == "pretrain" else 'configs/finetune_sample_weights.json'
57
+ # Load the sample weights
58
+ with open(sample_weights_cfg, 'r') as file:
59
+ SAMPLE_WEIGHTS = json.load(file)
60
+ self.openx_dir = OPENX_EMBOD_DIR
61
+ self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
62
+ self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
63
+ self.repeat = repeat
64
+
65
+ # Set the random seed
66
+ tf.random.set_seed(seed)
67
+ np.random.seed(seed)
68
+
69
+ # Weights of each dataset in the collection to sample from
70
+ sample_weights = []
71
+
72
+ self.name2dataset = {}
73
+ for dataset_name in self.dataset_names:
74
+ if dataset_name in DATASET_NAMES_NOOPENX:
75
+ dataset = globals()[dataset_name].load_dataset(seed)
76
+ else:
77
+ dataset_path = dataset_to_path(dataset_name, self.openx_dir)
78
+ dataset = tfds.builder_from_directory(builder_dir=dataset_path)
79
+ dataset = dataset.as_dataset(split='all', shuffle_files=True)
80
+
81
+ # You can add filter for other datasets
82
+ if dataset_name == 'kuka':
83
+ dataset = dataset.filter(
84
+ lambda x: x['success'])
85
+ elif dataset_name == 'bc_z':
86
+ dataset = dataset.filter(
87
+ lambda x: tf.math.greater(
88
+ next(iter(x['steps']))['observation']['episode_success'], 0.5))
89
+ elif dataset_name == 'ucsd_pick_and_place_dataset_converted_externally_to_rlds':
90
+ dataset = dataset.filter(
91
+ lambda x: x['episode_metadata']['success'])
92
+ elif dataset_name == 'utokyo_xarm_bimanual_converted_externally_to_rlds':
93
+ # Only preserve the meaningful episodes
94
+ dataset = dataset.filter(
95
+ lambda x: tf.math.equal(
96
+ next(iter(x['steps']))['language_instruction'],
97
+ tf.constant('Unfold a wrinkled towel.')))
98
+
99
+ # Note: use cache() will cause the unexpected crash
100
+ # dataset = dataset.map().cache().shuffle().repeat()
101
+ print(dataset_name)
102
+ dataset = dataset\
103
+ .map(
104
+ lambda x: process_episode(x, dataset_name,
105
+ IMAGE_KEYS[dataset_name]['image_keys'],
106
+ IMAGE_KEYS[dataset_name]['image_mask'])
107
+ )
108
+
109
+ # Change BGR to RGB if needed
110
+ if dataset_name == 'fmb':
111
+ dataset = dataset.map(bgr_to_rgb)
112
+
113
+ if self.repeat:
114
+ dataset = dataset.repeat()
115
+ self.name2dataset[dataset_name] = iter(dataset)
116
+ print(SAMPLE_WEIGHTS)
117
+ sample_weights.append(SAMPLE_WEIGHTS[dataset_name])
118
+ # Normalize the sample weights
119
+ sample_weights = np.array(sample_weights)
120
+ self.sample_weights = sample_weights / np.sum(sample_weights)
121
+
122
+ def __iter__(self):
123
+ '''
124
+ Sample episodes indefinitely, yielding one episode's flattened steps at a time.
125
+ '''
126
+ while True:
127
+ dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
128
+ episode = next(self.name2dataset[dataset_name])
129
+ if dataset_name == "agilex":
130
+ episode_steps = flatten_episode_agilex(episode)
131
+ else:
132
+ episode_steps = flatten_episode(episode)
133
+ # Filter too short
134
+ if len(episode_steps) < self.epsd_len_thresh_low:
135
+ continue
136
+ # Randomly sample too long
137
+ if len(episode_steps) > self.epsd_len_thresh_high:
138
+ episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)
139
+
140
+ yield episode_steps
141
+
142
+
143
+ if __name__ == "__main__":
144
+ dataset = VLADataset(0, 'finetune')
145
+ for episode in dataset:
146
+ print(episode[0])
147
+ break
encode_lang.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+
3
+ import torch
4
+ import yaml
5
+
6
+ from models.multimodal_encoder.t5_encoder import T5Embedder
7
+
8
+
9
+ GPU = 0
10
+ MODEL_PATH = "google/t5-v1_1-xxl"
11
+ CONFIG_PATH = "configs/base.yaml"
12
+ SAVE_DIR = "lang_embed/"
13
+
14
+ # Modify this to your task name and instruction
15
+ TASK_NAME = "anubis_carrot_to_bag"
16
+ # INSTRUCTION = "take the towel off the kirby doll"
17
+ # INSTRUCTION = "insert the brush to the dustpan"
18
+ INSTRUCTION = "pick up the carrot and put into the bag"
19
+
20
+ # Note: if your GPU VRAM is less than 24GB,
21
+ # it is recommended to enable offloading by specifying an offload directory.
22
+ # OFFLOAD_DIR = '/home/jellyho/OFFLOAD' # Specify your offload directory here, ensuring the directory exists.
23
+
24
+ def main():
25
+ with open(CONFIG_PATH, "r") as fp:
26
+ config = yaml.safe_load(fp)
27
+
28
+ device = torch.device(f"cuda:{GPU}")
29
+ text_embedder = T5Embedder(
30
+ from_pretrained=MODEL_PATH,
31
+ model_max_length=config["dataset"]["tokenizer_max_length"],
32
+ device=device,
33
+ # use_offload_folder=OFFLOAD_DIR
34
+ )
35
+ tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
36
+
37
+ tokens = tokenizer(
38
+ INSTRUCTION, return_tensors="pt",
39
+ padding="longest",
40
+ truncation=True
41
+ )["input_ids"].to(device)
42
+
43
+ tokens = tokens.view(1, -1)
44
+ with torch.no_grad():
45
+ pred = text_encoder(tokens).last_hidden_state.detach().cpu()
46
+
47
+ save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
48
+ # We save the embeddings in a dictionary format
49
+ torch.save({
50
+ "name": TASK_NAME,
51
+ "instruction": INSTRUCTION,
52
+ "embeddings": pred
53
+ }, save_path
54
+ )
55
+
56
+ print(f'\"{INSTRUCTION}\" from \"{TASK_NAME}\" is encoded by \"{MODEL_PATH}\" into shape {pred.shape} and saved to \"{save_path}\"')
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
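
The saved file is a plain PyTorch dictionary, so downstream code (for example, whatever is pointed at it via `--lang_embeddings_path` or used together with `--precomp_lang_embed`) can presumably load it back as sketched below; the path matches the `SAVE_DIR`/`TASK_NAME` defaults above:

```python
import torch

lang = torch.load("lang_embed/anubis_carrot_to_bag.pt")
print(lang["name"])               # "anubis_carrot_to_bag"
print(lang["instruction"])        # "pick up the carrot and put into the bag"
print(lang["embeddings"].shape)   # (1, num_tokens, hidden_dim) from the T5-XXL encoder
```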
finetune.sh ADDED
@@ -0,0 +1,57 @@
1
+ export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
2
+ export NCCL_IB_DISABLE=0
3
+ export NCCL_SOCKET_IFNAME=eth1
4
+ export NCCL_DEBUG=INFO
5
+ export NCCL_NVLS_ENABLE=0
6
+ export MASTER_PORT=$2
7
+ export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
8
+ export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
9
+ export OUTPUT_DIR="./checkpoints/$1"
10
+ export CFLAGS="-I/usr/include"
11
+ export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
12
+ export CUTLASS_PATH="/home/jellyho/cutlass"
13
+
14
+ export WANDB_PROJECT="robotics_diffusion_transformer"
15
+
16
+ if [ ! -d "$OUTPUT_DIR" ]; then
17
+ mkdir "$OUTPUT_DIR"
18
+ echo "Folder '$OUTPUT_DIR' created"
19
+ else
20
+ echo "Folder '$OUTPUT_DIR' already exists"
21
+ fi
22
+
23
+ # For run in a single node/machine
24
+ # accelerate launch main.py \
25
+ # --deepspeed="./configs/zero2.json" \
26
+ # ...
27
+ # --hostfile=hostfile.txt
28
+
29
+ accelerate launch --main_process_port $2 --num_processes 2 --num_machines 1 --mixed_precision bf16 main.py \
30
+ --deepspeed="./configs/zero2.json" \
31
+ --pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
32
+ --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
33
+ --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
34
+ --output_dir=$OUTPUT_DIR \
35
+ --train_batch_size=8 \
36
+ --sample_batch_size=8 \
37
+ --max_train_steps=50000 \
38
+ --checkpointing_period=5000 \
39
+ --sample_period=1000 \
40
+ --checkpoints_total_limit=10 \
41
+ --lr_scheduler="constant" \
42
+ --learning_rate=1e-4 \
43
+ --mixed_precision="bf16" \
44
+ --dataloader_num_workers=16 \
45
+ --image_aug \
46
+ --dataset_type="finetune" \
47
+ --gradient_accumulation_steps 1 \
48
+ --report_to=wandb \
49
+ --load_from_hdf5 \
50
+ --dataset_name $1 \
51
+ --precomp_lang_embed
52
+ # --resume_from_checkpoint="checkpoint-50000"
53
+
54
+ # Use this to resume training from some previous checkpoint
55
+ # --resume_from_checkpoint="checkpoint-36000" \
56
+ # Use this to load precomputed language instruction embeddings
57
+ # instead of computing them during training
finetune_maniskill.sh ADDED
@@ -0,0 +1,48 @@
1
+ export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
2
+ export NCCL_IB_DISABLE=0
3
+ export NCCL_SOCKET_IFNAME=bond0
4
+ export NCCL_DEBUG=INFO
5
+ export NCCL_NVLS_ENABLE=0
6
+
7
+ export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
8
+ export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
9
+ export OUTPUT_DIR="./checkpoints/rdt-finetune-1b-sim"
10
+ export CFLAGS="-I/usr/include"
11
+ export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
12
+ export CUTLASS_PATH="/data/lingxuan/cutlass"
13
+
14
+ export WANDB_PROJECT="robotic_diffusion_transformer"
15
+
16
+ if [ ! -d "$OUTPUT_DIR" ]; then
17
+ mkdir "$OUTPUT_DIR"
18
+ echo "Folder '$OUTPUT_DIR' created"
19
+ else
20
+ echo "Folder '$OUTPUT_DIR' already exists"
21
+ fi
22
+ # For run in a single node/machine
23
+ # accelerate launch main.py \
24
+ # --deepspeed="./configs/zero2.json" \
25
+ # ...
26
+
27
+ accelerate launch main.py \
28
+ --deepspeed="./configs/zero2.json" \
29
+ --pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
30
+ --pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
31
+ --pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
32
+ --output_dir=$OUTPUT_DIR \
33
+ --train_batch_size=24 \
34
+ --sample_batch_size=32 \
35
+ --max_train_steps=400000 \
36
+ --checkpointing_period=10000 \
37
+ --sample_period=500 \
38
+ --checkpoints_total_limit=40 \
39
+ --lr_scheduler="constant" \
40
+ --learning_rate=1e-4 \
41
+ --mixed_precision="bf16" \
42
+ --dataloader_num_workers=8 \
43
+ --image_aug \
44
+ --dataset_type="finetune" \
45
+ --state_noise_snr=40 \
46
+ --load_from_hdf5 \
47
+ --report_to=wandb
48
+
inference.sh ADDED
@@ -0,0 +1,9 @@
1
+ # Set --pretrained_model_name_or_path to your finetuned checkpoint, e.g.,
2
+ # checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER> or
3
+ # checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt,
4
+ # and set --ctrl_freq to your control frequency.
5
+ python -m scripts.agilex_inference \
6
+ --use_actions_interpolation \
7
+ --pretrained_model_name_or_path="checkpoints/your_finetuned_ckpt.pt" \
8
+ --lang_embeddings_path="outs/lang_embeddings/your_instr.pt" \
9
+ --ctrl_freq=25
main.py ADDED
@@ -0,0 +1,301 @@
1
+ import argparse
2
+ import os
3
+ from train.train import train
4
+
5
+ from accelerate.logging import get_logger
6
+
7
+
8
+ def parse_args(input_args=None):
9
+ parser = argparse.ArgumentParser(description="Main script for training RDT.")
10
+ parser.add_argument(
11
+ "--config_path",
12
+ type=str,
13
+ default="configs/base.yaml",
14
+ help="Path to the configuration file. Default is `configs/base.yaml`.",
15
+ )
16
+ parser.add_argument(
17
+ "--deepspeed",
18
+ type=str,
19
+ default=None,
20
+ help="Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
21
+ )
22
+ parser.add_argument(
23
+ "--pretrained_text_encoder_name_or_path",
24
+ type=str,
25
+ default=None,
26
+ help="Pretrained text encoder name or path if not the same as model_name",
27
+ )
28
+ parser.add_argument(
29
+ "--pretrained_vision_encoder_name_or_path",
30
+ type=str,
31
+ default=None,
32
+ help="Pretrained vision encoder name or path if not the same as model_name",
33
+ )
34
+
35
+ parser.add_argument(
36
+ "--output_dir",
37
+ type=str,
38
+ default="checkpoints",
39
+ help="The output directory where the model predictions and checkpoints will be written.",
40
+ )
41
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
42
+
43
+ parser.add_argument(
44
+ "--load_from_hdf5",
45
+ action="store_true",
46
+ default=False,
47
+ help=(
48
+ "Whether to load the dataset directly from HDF5 files. "
49
+ "If False, the dataset will be loaded using producer-consumer pattern, "
50
+ "where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."
51
+ )
52
+ )
53
+ parser.add_argument(
54
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
55
+ )
56
+ parser.add_argument(
57
+ "--sample_batch_size", type=int, default=8, help="Batch size (per device) for the sampling dataloader."
58
+ )
59
+ parser.add_argument(
60
+ "--num_sample_batches", type=int, default=2, help="Number of batches to sample from the dataset."
61
+ )
62
+ parser.add_argument("--num_train_epochs", type=int, default=1)
63
+ parser.add_argument(
64
+ "--max_train_steps",
65
+ type=int,
66
+ default=None,
67
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
68
+ )
69
+ parser.add_argument(
70
+ "--checkpointing_period",
71
+ type=int,
72
+ default=500,
73
+ help=(
74
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
75
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
76
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
77
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
78
+ "instructions."
79
+ ),
80
+ )
81
+ parser.add_argument(
82
+ "--checkpoints_total_limit",
83
+ type=int,
84
+ default=None,
85
+ help=(
86
+ "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
87
+ " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
88
+ " for more details"
89
+ ),
90
+ )
91
+ parser.add_argument(
92
+ "--resume_from_checkpoint",
93
+ type=str,
94
+ default=None,
95
+ help=(
96
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
97
+ ' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'
98
+ ),
99
+ )
100
+ parser.add_argument(
101
+ "--pretrained_model_name_or_path",
102
+ type=str,
103
+ default=None,
104
+ help=(
105
+ "Path or name of a pretrained checkpoint to load the model from.\n"
106
+ " This can be either:\n"
107
+ " - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
108
+ " - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
109
+ " - a path to a model checkpoint (*.pt), e.g., `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
110
+ " - `None` if you are randomly initializing model using configuration at `config_path`."
111
+ )
112
+ )
113
+ parser.add_argument(
114
+ "--gradient_accumulation_steps",
115
+ type=int,
116
+ default=1,
117
+ help="Number of update steps to accumulate before performing a backward/update pass.",
118
+ )
119
+ parser.add_argument(
120
+ "--gradient_checkpointing",
121
+ action="store_true",
122
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
123
+ )
124
+ parser.add_argument(
125
+ "--learning_rate",
126
+ type=float,
127
+ default=5e-6,
128
+ help="Initial learning rate (after the potential warmup period) to use.",
129
+ )
130
+ parser.add_argument(
131
+ "--cond_mask_prob",
132
+ type=float,
133
+ default=0.1,
134
+ help=(
135
+ "The probability to randomly mask the conditions (except states) during training. "
136
+ "If set to 0, the conditions are not masked."
137
+ ),
138
+ )
139
+ parser.add_argument(
140
+ "--cam_ext_mask_prob",
141
+ type=float,
142
+ default=-1.0,
143
+ help=(
144
+ "The probability to randomly mask the external camera image during training. "
145
+ "If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."
146
+ ),
147
+ )
148
+ parser.add_argument(
149
+ "--state_noise_snr",
150
+ type=float,
151
+ default=None,
152
+ help=(
153
+ "The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
154
+ "Default is None, which means no noise is added."
155
+ ),
156
+ )
157
+ parser.add_argument(
158
+ "--image_aug",
159
+ action="store_true",
160
+ default=False,
161
+ help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
162
+ )
163
+ parser.add_argument(
164
+ "--precomp_lang_embed",
165
+ action="store_true",
166
+ default=False,
167
+ help="Whether or not to use precomputed language embeddings.",
168
+ )
169
+ parser.add_argument(
170
+ "--scale_lr",
171
+ action="store_true",
172
+ default=False,
173
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
174
+ )
175
+ parser.add_argument(
176
+ "--lr_scheduler",
177
+ type=str,
178
+ default="constant",
179
+ help=(
180
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
181
+ ' "constant", "constant_with_warmup"]'
182
+ ),
183
+ )
184
+ parser.add_argument(
185
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
186
+ )
187
+ parser.add_argument(
188
+ "--lr_num_cycles",
189
+ type=int,
190
+ default=1,
191
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
192
+ )
193
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
194
+ parser.add_argument(
195
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
196
+ )
197
+ parser.add_argument(
198
+ "--dataloader_num_workers",
199
+ type=int,
200
+ default=0,
201
+ help=(
202
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
203
+ ),
204
+ )
205
+ parser.add_argument("--alpha", type=float, default=0.9, help="The moving average coefficient for each dataset's loss.")
206
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
207
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
208
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
209
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
210
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
211
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
212
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
213
+ parser.add_argument(
214
+ "--hub_model_id",
215
+ type=str,
216
+ default=None,
217
+ help="The name of the repository to keep in sync with the local `output_dir`.",
218
+ )
219
+ parser.add_argument(
220
+ "--logging_dir",
221
+ type=str,
222
+ default="logs",
223
+ help=(
224
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
225
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
226
+ ),
227
+ )
228
+ parser.add_argument(
229
+ "--allow_tf32",
230
+ action="store_true",
231
+ help=(
232
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
233
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
234
+ ),
235
+ )
236
+ parser.add_argument(
237
+ "--report_to",
238
+ type=str,
239
+ default="tensorboard",
240
+ help=(
241
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
242
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
243
+ ),
244
+ )
245
+ parser.add_argument(
246
+ "--sample_period",
247
+ type=int,
248
+ default=-1,
249
+ help=(
250
+ "Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
251
+ " and report the error between the sampled trajectory and ground-truth trajectory"
252
+ " in the training batch."
253
+ ),
254
+ )
255
+ parser.add_argument(
256
+ "--mixed_precision",
257
+ type=str,
258
+ default=None,
259
+ choices=["no", "fp16", "bf16"],
260
+ help=(
261
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
262
+ " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
263
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
264
+ ),
265
+ )
266
+
267
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
268
+ parser.add_argument(
269
+ "--set_grads_to_none",
270
+ action="store_true",
271
+ help=(
272
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
273
+ " behaviors, so disable this argument if it causes any problems. More info:"
274
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
275
+ ),
276
+ )
277
+
278
+ parser.add_argument('--dataset_type',
279
+ type=str,
280
+ default="pretrain",
281
+ required=False,
282
+ help="Whether to load the pretrain dataset or finetune dataset."
283
+ )
284
+ parser.add_argument('--dataset_name', type=str)
285
+
286
+ if input_args is not None:
287
+ args = parser.parse_args(input_args)
288
+ else:
289
+ args = parser.parse_args()
290
+
291
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
292
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
293
+ args.local_rank = env_local_rank
294
+
295
+ return args
296
+
297
+
298
+ if __name__ == "__main__":
299
+ logger = get_logger(__name__)
300
+ args = parse_args()
301
+ train(args, logger)
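
Since `parse_args` accepts an explicit argument list, the entry point can also be driven programmatically. A sketch, assuming the repository's training dependencies are installed; the flag values are placeholders echoing finetune.sh:

```python
from accelerate.logging import get_logger

from main import parse_args
from train.train import train

args = parse_args([
    "--config_path", "configs/base.yaml",
    "--pretrained_model_name_or_path", "robotics-diffusion-transformer/rdt-1b",
    "--output_dir", "checkpoints/my_run",   # placeholder output directory
    "--train_batch_size", "8",
    "--max_train_steps", "50000",
    "--dataset_type", "finetune",
    "--load_from_hdf5",
    "--precomp_lang_embed",
])
train(args, get_logger(__name__))
```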
models/ema_model.py ADDED
@@ -0,0 +1,89 @@
1
+ # Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]
2
+
3
+ import torch
4
+ from torch.nn.modules.batchnorm import _BatchNorm
5
+
6
+
7
+ class EMAModel:
8
+ """
9
+ Exponential Moving Average of model weights.
10
+ """
11
+ def __init__(
12
+ self,
13
+ model,
14
+ update_after_step=0,
15
+ inv_gamma=1.0,
16
+ power=2 / 3,
17
+ min_value=0.0,
18
+ max_value=0.9999
19
+ ):
20
+ """
21
+ @crowsonkb's notes on EMA Warmup:
22
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
23
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
24
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
25
+ at 215.4k steps).
26
+ Args:
27
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
28
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
29
+ min_value (float): The minimum EMA decay rate. Default: 0.
30
+ """
31
+
32
+ self.averaged_model = model
33
+ self.averaged_model.eval()
34
+ self.averaged_model.requires_grad_(False)
35
+
36
+ self.update_after_step = update_after_step
37
+ self.inv_gamma = inv_gamma
38
+ self.power = power
39
+ self.min_value = min_value
40
+ self.max_value = max_value
41
+
42
+ self.decay = 0.0
43
+ self.optimization_step = 0
44
+
45
+ def get_decay(self, optimization_step):
46
+ """
47
+ Compute the decay factor for the exponential moving average.
48
+ """
49
+ step = max(0, optimization_step - self.update_after_step - 1)
50
+ value = 1 - (1 + step / self.inv_gamma) ** -self.power
51
+
52
+ if step <= 0:
53
+ return 0.0
54
+
55
+ return max(self.min_value, min(value, self.max_value))
56
+
57
+ @torch.no_grad()
58
+ def step(self, new_model):
59
+ self.decay = self.get_decay(self.optimization_step)
60
+
61
+ # old_all_dataptrs = set()
62
+ # for param in new_model.parameters():
63
+ # data_ptr = param.data_ptr()
64
+ # if data_ptr != 0:
65
+ # old_all_dataptrs.add(data_ptr)
66
+
67
+ all_dataptrs = set()
68
+ for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
69
+ for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
70
+ # Iterate over immediate parameters only.
71
+ if isinstance(param, dict):
72
+ raise RuntimeError('Dict parameter not supported')
73
+
74
+ # data_ptr = param.data_ptr()
75
+ # if data_ptr != 0:
76
+ # all_dataptrs.add(data_ptr)
77
+
78
+ if isinstance(module, _BatchNorm):
79
+ # skip batchnorms
80
+ ema_param.copy_(param.to(dtype=ema_param.dtype).data)
81
+ elif not param.requires_grad:
82
+ ema_param.copy_(param.to(dtype=ema_param.dtype).data)
83
+ else:
84
+ ema_param.mul_(self.decay)
85
+ ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
86
+
87
+ # verify that iterating over module and then parameters is identical to parameters recursively.
88
+ # assert old_all_dataptrs == all_dataptrs
89
+ self.optimization_step += 1
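
For context, `EMAModel.step` is meant to be called once per optimizer step. A minimal usage sketch with a toy network, passing a separate copy of the model to hold the averaged weights:

```python
import copy

import torch

from models.ema_model import EMAModel

model = torch.nn.Linear(16, 16)                    # toy network for illustration
ema = EMAModel(copy.deepcopy(model), power=2 / 3)  # the EMA keeps its own copy of the weights

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(100):
    x = torch.randn(8, 16)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.step(model)   # update the running average after every optimizer step

# ema.averaged_model now holds the smoothed weights for evaluation or checkpointing
```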