Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +184 -0
- LICENSE +21 -0
- README.md +357 -0
- Untitled.ipynb +86 -0
- data/.gitignore +2 -0
- data/compute_dataset_stat.py +240 -0
- data/compute_dataset_stat_hdf5.py +100 -0
- data/episode_transform.py +406 -0
- data/filelock.py +24 -0
- data/hdf5_maniskill_dataset.py +243 -0
- data/hdf5_vla_dataset.py +533 -0
- data/preprocess.py +323 -0
- data/preprocess_scripts/__init__.py +73 -0
- data/preprocess_scripts/aloha_shoes_table.py +55 -0
- data/preprocess_scripts/austin_buds_dataset_converted_externally_to_rlds.py +82 -0
- data/preprocess_scripts/berkeley_autolab_ur5.py +95 -0
- data/preprocess_scripts/berkeley_cable_routing.py +73 -0
- data/preprocess_scripts/berkeley_gnm_sac_son.py +78 -0
- data/preprocess_scripts/berkeley_rpt_converted_externally_to_rlds.py +84 -0
- data/preprocess_scripts/calvin.py +176 -0
- data/preprocess_scripts/cmu_franka_exploration_dataset_converted_externally_to_rlds.py +75 -0
- data/preprocess_scripts/cmu_play_fusion.py +82 -0
- data/preprocess_scripts/cmu_stretch.py +84 -0
- data/preprocess_scripts/droid.py +78 -0
- data/preprocess_scripts/fractal20220817_data.py +92 -0
- data/preprocess_scripts/iamlab_cmu_pickup_insert_converted_externally_to_rlds.py +80 -0
- data/preprocess_scripts/libero_goal_no_noops.py +82 -0
- data/preprocess_scripts/libero_spatial_no_noops.py +82 -0
- data/preprocess_scripts/nyu_rot_dataset_converted_externally_to_rlds.py +82 -0
- data/preprocess_scripts/robo_net.py +71 -0
- data/preprocess_scripts/robomimic_lift_ph.py +97 -0
- data/preprocess_scripts/robomimic_square_ph.py +97 -0
- data/preprocess_scripts/roboset.py +367 -0
- data/preprocess_scripts/roboturk.py +77 -0
- data/preprocess_scripts/roboturk_real_objectsearch.py +217 -0
- data/preprocess_scripts/roboturk_real_towercreation.py +223 -0
- data/preprocess_scripts/stanford_hydra_dataset_converted_externally_to_rlds.py +94 -0
- data/preprocess_scripts/tokyo_u_lsmo_converted_externally_to_rlds.py +90 -0
- data/preprocess_scripts/utokyo_pr2_opening_fridge_converted_externally_to_rlds.py +92 -0
- data/preprocess_scripts/utokyo_xarm_bimanual_converted_externally_to_rlds.py +117 -0
- data/preprocess_scripts/viola.py +89 -0
- data/producer.py +280 -0
- data/utils.py +235 -0
- data/vla_dataset.py +147 -0
- encode_lang.py +60 -0
- finetune.sh +57 -0
- finetune_maniskill.sh +48 -0
- inference.sh +5 -0
- main.py +301 -0
- models/ema_model.py +89 -0
.gitignore
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
161 |
+
|
162 |
+
# Some encoder paths
|
163 |
+
facebook/
|
164 |
+
openai/
|
165 |
+
google/
|
166 |
+
|
167 |
+
# Log
|
168 |
+
logs/
|
169 |
+
|
170 |
+
# Output
|
171 |
+
outs/
|
172 |
+
|
173 |
+
# Checkpoints
|
174 |
+
checkpoints/
|
175 |
+
|
176 |
+
# VSC
|
177 |
+
.vscode/
|
178 |
+
|
179 |
+
# Wandb
|
180 |
+
wandb/
|
181 |
+
|
182 |
+
# Distributed leaning
|
183 |
+
hostfile.txt
|
184 |
+
.deepspeed_env
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 TSAIL group
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation
|
2 |
+
|
3 |
+
### 📝[Paper](https://arxiv.org/pdf/2410.07864) | 🌍[Project Page](https://rdt-robotics.github.io/rdt-robotics/) | 🤗[Model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) | 🛢️[Data](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
RDT-1B is a **1B**-parameter (*largest* to date) imitation learning **Diffusion Transformer** pre-trained on **1M+** (*largest* to date) multi-robot episodes. Given language instruction and RGB images of up to three views, RDT can predict the next $64$ robot actions. RDT is inherently compatible with **almost all kinds of modern mobile manipulators**, from single-arm to dual-arm, joint to EEF, position to velocity, and even with wheeled locomotion.
|
8 |
+
|
9 |
+
We have fine-tuned RDT on **6K+** (one of the *largest*) self-collected bimanual episodes and deployed it on the ALOHA **dual-arm** robot. It has achieved state-of-the-art performance in terms of dexterity, zero-shot generalizability, and few-shot learning. You can find Demo videos on our [project page](https://rdt-robotics.github.io/rdt-robotics/).
|
10 |
+
|
11 |
+
This repo is an official PyTorch implementation of RDT, containing:
|
12 |
+
|
13 |
+
- 🛠️Model [implementation](models/rdt_runner.py) of RDT
|
14 |
+
- 🤗1M-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-1b) of RDT-1B pre-trained on multi-robot data
|
15 |
+
- 🤗500K-step [checkpoint](https://huggingface.co/robotics-diffusion-transformer/rdt-170m) of RDT-170M (RDT(small) in [ablation](https://arxiv.org/pdf/2410.07864))
|
16 |
+
- 📈Training and sampling [scripts](train/train.py) (with DeepSpeed)
|
17 |
+
- 🤖An [example](scripts/agilex_inference.py) of real-robot deployment
|
18 |
+
- 🕹️Simulation benchmark from [Maniskill](https://github.com/haosulab/ManiSkill) environment
|
19 |
+
|
20 |
+
The following guides include the [installation](#installation), [fine-tuning](#fine-tuning-on-your-own-dataset), and [deployment](#deployment-on-real-robots). Please refer to [pre-training](docs/pretrain.md) for a detailed list of pre-training datasets and a pre-training guide.
|
21 |
+
|
22 |
+
## 📰 News
|
23 |
+
- [2024/12/17] 🔥 [Scripts](#simulation-benchmark) for evaluating RDT in Maniskill Simulation Benchmark is released!
|
24 |
+
- [2024/10/23] 🔥 **RDT-170M** (Smaller) model is released, a more VRAM-friendly solution 🚀💻.
|
25 |
+
|
26 |
+
## Installation
|
27 |
+
|
28 |
+
1. Clone this repo and install prerequisites:
|
29 |
+
|
30 |
+
```bash
|
31 |
+
# Clone this repo
|
32 |
+
git clone [email protected]:thu-ml/RoboticsDiffusionTransformer.git
|
33 |
+
cd RoboticsDiffusionTransformer
|
34 |
+
|
35 |
+
# Create a Conda environment
|
36 |
+
conda create -n rdt python=3.10.0
|
37 |
+
conda activate rdt
|
38 |
+
|
39 |
+
# Install pytorch
|
40 |
+
# Look up https://pytorch.org/get-started/previous-versions/ with your cuda version for a correct command
|
41 |
+
pip install torch==2.1.0 torchvision==0.16.0 --index-url https://download.pytorch.org/whl/cu121
|
42 |
+
|
43 |
+
# Install packaging
|
44 |
+
pip install packaging==24.0
|
45 |
+
|
46 |
+
# Install flash-attn
|
47 |
+
pip install flash-attn --no-build-isolation
|
48 |
+
|
49 |
+
# Install other prequisites
|
50 |
+
pip install -r requirements.txt
|
51 |
+
```
|
52 |
+
|
53 |
+
2. Download off-the-shelf multi-modal encoders:
|
54 |
+
|
55 |
+
You can download the encoders from the following links:
|
56 |
+
|
57 |
+
- `t5-v1_1-xxl`: [link](https://huggingface.co/google/t5-v1_1-xxl/tree/main)🤗
|
58 |
+
- `siglip`: [link](https://huggingface.co/google/siglip-so400m-patch14-384)🤗
|
59 |
+
|
60 |
+
And link the encoders to the repo directory:
|
61 |
+
|
62 |
+
```bash
|
63 |
+
# Under the root directory of this repo
|
64 |
+
mkdir -p google
|
65 |
+
|
66 |
+
# Link the downloaded encoders to this repo
|
67 |
+
ln -s /path/to/t5-v1_1-xxl google/t5-v1_1-xxl
|
68 |
+
ln -s /path/to/siglip-so400m-patch14-384 google/siglip-so400m-patch14-384
|
69 |
+
```
|
70 |
+
3. Fill the missing argument in [this file](configs/base.yaml#L22):
|
71 |
+
|
72 |
+
Note that this buffer will only be used during pre-training. See [this doc](docs/pretrain.md) for more details.
|
73 |
+
```
|
74 |
+
# ...
|
75 |
+
|
76 |
+
dataset:
|
77 |
+
# ...
|
78 |
+
# ADD YOUR buf_path: the path to the buffer (at least 400GB)
|
79 |
+
buf_path: /path/to/buffer
|
80 |
+
# ...
|
81 |
+
```
|
82 |
+
|
83 |
+
## Fine-Tuning on Your Own Dataset
|
84 |
+
|
85 |
+
If your fine-tuning dataset is in the [Open X-Embodiment](https://robotics-transformer-x.github.io/) or the collection of our pre-training datasets (see [this doc](docs/pretrain.md#download-and-prepare-datasets)), you can also fine-tune RDT through the pre-trained pipeline. You need to remove other redundant datasets in the parameters. We refer to [this guide](docs/pretrain.md) (pre-training).
|
86 |
+
|
87 |
+
1. Prepare your dataset:
|
88 |
+
|
89 |
+
You need to download your dataset to the disk and give it a name `my_cool_dataset`.
|
90 |
+
|
91 |
+
Then, you can link your dataset to the repo directory:
|
92 |
+
|
93 |
+
```bash
|
94 |
+
# Under the root directory of this repo
|
95 |
+
cd data
|
96 |
+
mkdir -p datasets
|
97 |
+
|
98 |
+
# Link the downloaded dataset to this repo
|
99 |
+
ln -s /path/to/my_cool_dataset datasets/my_cool_dataset
|
100 |
+
```
|
101 |
+
|
102 |
+
2. Implement the dataset loader:
|
103 |
+
|
104 |
+
You need to:
|
105 |
+
|
106 |
+
1. Register the configuration of `my_cool_dataset`:
|
107 |
+
|
108 |
+
Append the control frequency of `my_cool_dataset` in [this file](configs/dataset_control_freq.json). Write the name of `my_cool_dataset` in [this file](configs/finetune_datasets.json) and [this file](configs/finetune_sample_weights.json), where the value of the sampling weight doesn't matter since you only have one dataset. In these two files, we leave a placeholder of `agilex`; you can simply replace it with `my_cool_dataset`.
|
109 |
+
|
110 |
+
2. Re-Implement the class of `HDF5VLADataset`:
|
111 |
+
|
112 |
+
You can find this class in [this file](data/hdf5_vla_dataset.py). In this file, we provide an example of loading the fine-tuning dataset used in our paper (see [this link](https://huggingface.co/datasets/robotics-diffusion-transformer/rdt-ft-data)).
|
113 |
+
|
114 |
+
To adapt it to your dataset, you need to: (a) modify the `HDF5_DIR` (directory to `my_cool_dataset`) and `DATASET_NAME` (should be `"my_cool_dataset"`) in L21 and L22; (b) Implement the two functions of `parse_hdf5_file()` and `parse_hdf5_file_state_only()`. Please take a look at the original file for detailed comments and examples.
|
115 |
+
|
116 |
+
Note 1: Despite its name, you don't necessarily need to use HDF5 to store your data. Just make sure that the class is correctly implemented.
|
117 |
+
|
118 |
+
Note 2: During implementation, you may need to fill your robot action into the unified action vector (L180-194). Please refer to [this file](configs/state_vec.py) for an explanation of each element in the unified vector. We have reserved enough slots for each physical quantity. For example, we have reserved ten slots for joint angles. If your robot arm has six degrees of freedom, you only need to fill in the first six.
|
119 |
+
|
120 |
+
**IMPORTANT 1:** If your robot is single-arm, please fill its action into the *right-arm* portion of the unified action vector, aligning with our pre-training datasets.
|
121 |
+
|
122 |
+
**IMPORTANT 2:** We use [6D representation](https://arxiv.org/pdf/1812.07035) for EEF rotation. If your action space contains EEF rotation (angle or quaternion), please refer to [this file](docs/test_6drot.py) for conversion. We note that this mapping is not reversible. Different Euler angles may be equivalent and correspond to the same 6D representation.
|
123 |
+
|
124 |
+
**IMPORTANT 3:** No physical quantities (except the gripper width) are normalized during pre-training. This can preserve each physical quantity's meaning, thereby promoting generalization across robots. Therefore, we encourage you not to normalize any physical quantities but to choose appropriate units for them. Generally, we use the International System of Units, which ensures that most values fall within [-1,1]. As an exception, we perform min-max normalization on the gripper width to [0,1].
|
125 |
+
|
126 |
+
**IMPORTANT 4:** If you use RTX 4090 (or lower), the GPU memory may be too low to load the `t5-v1_1-xxl` encoder. Instead, we recommend you precompute the language embeddings (see [this file](scripts/encode_lang_batch.py) for an example script) and load them during training. In this way, you need to specify the path to the embeddings in the `HDF5VLADataset` (see L148) rather than the natural language.
|
127 |
+
|
128 |
+
3. Compute the dataset statistics information for `my_cool_dataset`:
|
129 |
+
|
130 |
+
```bash
|
131 |
+
# Under the root directory of this repo
|
132 |
+
# Use -h to see the full usage
|
133 |
+
python -m data.compute_dataset_stat_hdf5
|
134 |
+
```
|
135 |
+
|
136 |
+
3. Start fine-tuning:
|
137 |
+
|
138 |
+
Configurations relevant to model architecture and data processing are in [this file](configs/base.yaml). Normally, you do not need to modify these configurations; otherwise, it will cause errors in loading the pre-training checkpoint. Configurations relevant to training are passed through *Command Line Arguments*. Use `python main.py -h ` to see the descriptions. We provide an example of a fine-tuning script in [this file](finetune.sh) (`finetune.sh`). You may need to modify some of the parameters in this file, such as `CUTLASS_PATH` and `WANDB_PROJECT`.
|
139 |
+
|
140 |
+
Use this to start fine-tuning:
|
141 |
+
|
142 |
+
```bash
|
143 |
+
source finetune.sh
|
144 |
+
```
|
145 |
+
|
146 |
+
with `finetune.sh` detailed as below:
|
147 |
+
|
148 |
+
```bash
|
149 |
+
deepspeed --hostfile=hostfile.txt main.py \
|
150 |
+
--deepspeed="./configs/zero2.json" \ # If you want to use DeepSpeed, which is strongly recommended
|
151 |
+
--pretrained_model_name_or_path=<MODEL ID | DIRECTORY OF MODEL WEIGHTS | PATH TO MODEL CHECKPOINT> \
|
152 |
+
--pretrained_text_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY > \ # e.g., google/t5-v1_1-xxl
|
153 |
+
--pretrained_vision_encoder_name_or_path=<MODEL ID | PATH TO MODEL DIRECTORY> \ # e.g., google/siglip-so400m-patch14-384
|
154 |
+
--output_dir=<DIRECTORY to SAVE CHECKPOINTS> \ # e.g., checkpoints/rdt-1b-agilex
|
155 |
+
--train_batch_size=32 \
|
156 |
+
--sample_batch_size=64 \ # batch size for diffusion sampling in validation
|
157 |
+
--max_train_steps=200000 \
|
158 |
+
--checkpointing_period=1000 \
|
159 |
+
--sample_period=500 \ # sample period for validation
|
160 |
+
--checkpoints_total_limit=40 \
|
161 |
+
--lr_scheduler="constant" \
|
162 |
+
--learning_rate=1e-4 \
|
163 |
+
--mixed_precision="bf16" \ # If you want to use mixed precision, bf16 is recommended
|
164 |
+
--dataloader_num_workers=8 \
|
165 |
+
--image_aug \ # If you want to use image augmentation
|
166 |
+
--dataset_type="finetune" \
|
167 |
+
--state_noise_snr=40 \ # If you want to add noise to the state
|
168 |
+
--load_from_hdf5 \ # If you use HDF5 to store your data
|
169 |
+
--report_to=wandb
|
170 |
+
```
|
171 |
+
|
172 |
+
**IMPORTANT**: If you have already chosen to precompute the language embeddings, please specify `--precomp_lang_embed` in the `finetune.sh`.
|
173 |
+
|
174 |
+
Note 1: `pretrained_model_name_or_path` can one of:
|
175 |
+
|
176 |
+
- a string, the *model id* of a pre-trained model hosted inside a model repo on HuggingFace. Please fill with `"robotics-diffusion-transformer/rdt-1b"`, which is the officially-released [RDT-1B model](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗 at HuggingFace. (recommended)
|
177 |
+
- a string, the path to a *directory* containing the manually downloaded model weights from HuggingFace, e.g., `"/path/to/rdt-1b"`. You should first manually download the `rdt-1b` directory from this [link](https://huggingface.co/robotics-diffusion-transformer/rdt-1b)🤗.
|
178 |
+
- a string, the path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method. This can be either:
|
179 |
+
- `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>"`: This is the path to the checkpoint saved in the `<STEP NUMBE>` iteration during pre-training. Refer to [this file](docs/pretrain.md) for a tutorial on how to start your own pre-training.
|
180 |
+
- `"checkpoints/rdt-pretrain-1b"`: If the pre-training completes normally without any exception, you can specify this path to load the last checkpoint.
|
181 |
+
- a string, the path to model checkpoint (`*.pt`) saved by DeepSpeed, e.g., `"checkpoints/rdt-pretrain-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt"` (verified)
|
182 |
+
- `None` if you want to randomly initialize the model using configuration at `config_path`.
|
183 |
+
|
184 |
+
Note 2: You can monitor the training process by observing `loss` (through a long window moving average) and `overall_avg_sample_mse` in [Wandb](https://wandb.ai/site) or [TensorBoard](https://www.tensorflow.org/tensorboard). We empirically found that the lower the `overall_avg_sample_mse`, the better the model performs. Usually, fine-tuning is over when this value converges.
|
185 |
+
|
186 |
+
Note 3: If the training oscillates, you can increase the batch size by adding more GPUs or setting a larger `--gradient_accumulation_steps`.
|
187 |
+
|
188 |
+
## Deployment on Real-Robots
|
189 |
+
|
190 |
+
We have encapsulated the inference of the model into a class named `RoboticDiffusionTransformerModel` (see [this file](scripts/agilex_model.py#L38)). You can call this class's `step()` method for inference. However, you may need to re-implement some parts according to your specific robot. You should at least modify the `_format_joint_to_state()` (L164) and `_unformat_action_to_joint()` (L196) to convert between robot raw actions and unified action vectors that RDT accepts. You may also specify the control frequency of your robot (L49).
|
191 |
+
|
192 |
+
**IMPORTANT**: When you feed the images into `step()`, remember the order MUST be `[ext_{t-1}, right_wrist_{t-1}, left_wrist_{t-1}, ext_{t}, right_wrist_{t}, left_wrist_{t}]`.
|
193 |
+
|
194 |
+
We provide an example hardware code in [this file](scripts/agilex_inference.py) for deployment on Mobile ALOHA, and the corresponding running script in [this file](inference.sh) (`inference.sh`), which is detailed below;
|
195 |
+
|
196 |
+
```bash
|
197 |
+
python -m scripts.agilex_inference \
|
198 |
+
--use_actions_interpolation \
|
199 |
+
--pretrained_model_name_or_path=<PATH TO MODEL CHECKPOINT> \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>, checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt, the same before
|
200 |
+
--lang_embeddings_path=<PATH TO YOUR INSTURCTION EMBEDDINGS> \ # e.g. outs/lang_embeddings/your_instr.pt"
|
201 |
+
--ctrl_freq=25 # your control frequency
|
202 |
+
```
|
203 |
+
|
204 |
+
**IMPORTANT**: If you on-board GPU memory is not enough to encode the language, please refer to [this file](scripts/encode_lang.py) for precomputation and specify the language embedding path in `inference.sh`. Detail instructions are provided below:
|
205 |
+
|
206 |
+
1. Set Required Parameters in `scripts/encode_lang.py`
|
207 |
+
|
208 |
+
```python
|
209 |
+
# ...
|
210 |
+
|
211 |
+
GPU = 0
|
212 |
+
MODEL_PATH = "google/t5-v1_1-xxl"
|
213 |
+
CONFIG_PATH = "configs/base.yaml"
|
214 |
+
SAVE_DIR = "outs/" # output directory
|
215 |
+
|
216 |
+
# Modify this to your task name and instruction
|
217 |
+
TASK_NAME = "handover_pan"
|
218 |
+
INSTRUCTION = "Pick up the black marker on the right and put it into the packaging box on the left."
|
219 |
+
|
220 |
+
# Note: if your GPU VRAM is less than 24GB,
|
221 |
+
# it is recommended to enable offloading by specifying an offload directory.
|
222 |
+
OFFLOAD_DIR = None # Specify your offload directory here, ensuring the directory exists.
|
223 |
+
|
224 |
+
# ...
|
225 |
+
```
|
226 |
+
|
227 |
+
2. Run the script
|
228 |
+
```
|
229 |
+
python -m scripts.encode_lang
|
230 |
+
```
|
231 |
+
|
232 |
+
Note: If you want to deploy on the Mobile ALOHA robot, don't forget to install the hardware prerequisites (see [this repo](https://github.com/MarkFzp/mobile-aloha)).
|
233 |
+
|
234 |
+
## Simulation Benchmark
|
235 |
+
|
236 |
+
We comprehensively evaluate RDT against baseline methods using the ManiSkill simulation benchmark. Specifically, we focus on five benchmark tasks: `PegInsertionSide`, `PickCube`, `StackCube`, `PlugCharger`, and `PushCube`. Here's a brief overview of the evaluation setup:
|
237 |
+
|
238 |
+
**Evaluation Setup:**
|
239 |
+
|
240 |
+
1. **Install ManiSkill:**
|
241 |
+
Within the [RDT environment](#installation), install ManiSkill as follows:
|
242 |
+
```bash
|
243 |
+
conda activate rdt
|
244 |
+
pip install --upgrade mani_skill
|
245 |
+
```
|
246 |
+
|
247 |
+
2. **Configure Vulkan:**
|
248 |
+
Follow the [ManiSkill documentation](https://maniskill.readthedocs.io/en/latest/user_guide/getting_started/installation.html#vulkan) to properly set up Vulkan。
|
249 |
+
|
250 |
+
3. **Obtain Model Weights:**
|
251 |
+
Download the fine-tuned model weights from [this Hugging Face repository](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/rdt). Download the precomputed language embeddings from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model/tree/main/lang_embeds) to the root directory of this repo.
|
252 |
+
|
253 |
+
4. **Run Evaluation Scripts:**
|
254 |
+
After completing the setup steps, execute the provided evaluation scripts to assess RDT on the selected tasks.
|
255 |
+
|
256 |
+
```
|
257 |
+
conda activate rdt
|
258 |
+
python -m eval_sim.eval_rdt_maniskill \
|
259 |
+
--pretrained_path PATH_TO_PRETRAINED_MODEL
|
260 |
+
```
|
261 |
+
|
262 |
+
### Implementation Details
|
263 |
+
|
264 |
+
#### Data
|
265 |
+
|
266 |
+
Utilizing the [official ManiSkill repository](https://github.com/haosulab/ManiSkill), we generated 5,000 trajectories through motion planning. The initial action mode of these trajectories is absolute joint position control and we subsequently converted them into delta end-effector pose control to align with the pre-training action space of OpenVLA and Octo. We strictly adhered to the official codebases of OpenVLA and Octo, modifying only the dataset-loading scripts. Consequently, we finetuned OpenVLA and Octo using the delta end-effector pose data. For RDT and Diffusion-Policy we leverage joint position control data for training which is aligned with our pre-training stage as well.
|
267 |
+
|
268 |
+
#### Training
|
269 |
+
- OpenVLA is fine-tuned from the officially released pre-trained checkpoint with LoRA-rank 32 until converge.
|
270 |
+
- Octo is fine-tuned from the officially released pre-trained checkpoint for 1M iterations until converge.
|
271 |
+
- Diffusion-Policy is trained from scratch for 1000 epochs. We select the checkpoint of 700 epoch which has the lowest validation sample loss of 1e-3.
|
272 |
+
- RDT is fine-tuned from our released pre-trained checkpoint for 300ks iterations.
|
273 |
+
|
274 |
+
#### Results
|
275 |
+
|
276 |
+
Each method is evaluated over 250 trials (10 random seeds with 25 trials per seed). The quantitative results, including success rate mean and std value across 10 random seeds are presented below:
|
277 |
+
|
278 |
+
|
279 |
+
||PegInsertionSide|PickCube|StackCube|PlugCharger|PushCube|Mean|
|
280 |
+
|---|---|---|---|---|---|---|
|
281 |
+
|RDT|**13.2±0.29%**|**77.2±0.48%**|74.0±0.30%|**1.2±0.07%**|**100±0.00%**|**53.6±0.52%**|
|
282 |
+
|OpenVLA|0.0±0.00%|8±0.00%|8±0.00%|0.0±0.00%|8±0.00%|4.8±0.00%|
|
283 |
+
|Octo|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|0.0±0.00%|
|
284 |
+
|Diffusion-Policy|0.0±0.00%|40.0±0.00%|**80.0±0.00%**|0.0%±0.00%|88.0±0.00%|30.2±0.00%|
|
285 |
+
|
286 |
+
#### Finetune RDT with Maniskill Data
|
287 |
+
|
288 |
+
To fine-tune RDT with Maniskill data, first download the Maniskill data from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) and extract it to `data/datasets/rdt-ft-data`. Then copy the code in `data/hdf5_vla_dataset.py` to `data/hdf5_maniskill_dataset.py` and run the following script:
|
289 |
+
|
290 |
+
```
|
291 |
+
bash finetune_maniskill.sh
|
292 |
+
```
|
293 |
+
|
294 |
+
#### Reproducing Baseline Results
|
295 |
+
|
296 |
+
Download and extract the fine-tuned model weights from [here](https://huggingface.co/robotics-diffusion-transformer/maniskill-model) to `eval_sim/`.
|
297 |
+
|
298 |
+
- OpenVLA: Clone [OpenVLA repo](https://github.com/openvla/openvla) in `./eval_sim/` and install its environment & ManiSkill. Then run the following script:
|
299 |
+
```
|
300 |
+
python -m eval_sim.eval_openvla --pretrained_path PATH_TO_PRETRAINED_MODEL
|
301 |
+
```
|
302 |
+
- Octo: Clone [Octo repo](https://github.com/octo-models/octo.git) in `./eval_sim/` and install its environment & ManiSkill. The run the following script:
|
303 |
+
```
|
304 |
+
python -m eval_sim.eval_octo --pretrained_path PATH_TO_PRETRAINED_MODEL
|
305 |
+
```
|
306 |
+
- Diffusion-Policy: Clone our simplified [Diffusion-Policy repo](https://github.com/LBG21/RDT-Eval-Diffusion-Policy) in `./eval_sim/` and run:
|
307 |
+
```
|
308 |
+
python -m eval_sim.eval_dp --pretrained_path PATH_TO_PRETRAINED_MODEL
|
309 |
+
```
|
310 |
+
|
311 |
+
## FAQ
|
312 |
+
|
313 |
+
### 1. How can I fine-tune RDTs with limited VRAM?
|
314 |
+
|
315 |
+
- **Use a Smaller Model**: Opt for the [RDT-170M model](https://huggingface.co/robotics-diffusion-transformer/rdt-170m), which requires less VRAM.
|
316 |
+
|
317 |
+
- **Select a Memory-Efficient ZeRO Stage**: Choose a more memory-efficient ZeRO stage based on your needs:
|
318 |
+
- **ZeRO-3 with Offload** > **ZeRO-3** > **ZeRO-2 with Offload** > **ZeRO-2** > **ZeRO-1**
|
319 |
+
- By default, we use [ZeRO-2](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/finetune.sh#L29) for a balance between speed and memory efficiency. Find more details on ZeRO stages [here](https://huggingface.co/docs/transformers/main/deepspeed#select-a-zero-stage) and [here](https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training).
|
320 |
+
|
321 |
+
- **Enable 8-bit Adam Optimization**: Activate 8-bit Adam by setting [`use_8bit_adam=True`](https://github.com/thu-ml/RoboticsDiffusionTransformer/blob/c68398ed526733faca4eec52cc1a7d15a9f8fea7/main.py#L195) for reduced memory usage during training.
|
322 |
+
|
323 |
+
- **Apply 4-bit or 8-bit Quantization**: Quantizing model weights can significantly reduce VRAM requirements.
|
324 |
+
|
325 |
+
- **Use [XFormers](https://github.com/facebookresearch/xformers)**: This library provides optimized transformers with efficient memory usage.
|
326 |
+
|
327 |
+
- **Enable Gradient Checkpointing**: Implement `gradient_checkpointing` manually to save memory during backpropagation. See [here](https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html) for instructions. Once you have successfully implemented this feature, we welcome you to submit a PR👏.
|
328 |
+
- **Gradient Accumulation**: Set a larger `--gradient_accumulation_steps=<num_steps>`. This will accumulate the gradients of `<num_steps>` batches for backpropagation. Equivalently, this will increase the batch size by `<num_steps>` times, at the cost of `<num_steps>` times the running time.
|
329 |
+
|
330 |
+
### 2. How many steps are recommended for fine-tuning RDT?
|
331 |
+
|
332 |
+
Regardless of the batch size you select, it is recommended to train for at least 150K steps to achieve optimal results.
|
333 |
+
|
334 |
+
### 3. What to do if t5-xxL is too large to store in GPU memory?
|
335 |
+
|
336 |
+
1. Do not load T5-XXL in your GPU memory when training. Pre-compute language embeddings in advance.
|
337 |
+
2. Set `OFFLOAD_DIR` to enable CPU offloading in `scripts/encode_lang_batch.py` and `scripts/encode_lang.py`.
|
338 |
+
3. Use smaller versions of t5 like t5-base instead of t5-xxL.
|
339 |
+
|
340 |
+
## Citation
|
341 |
+
|
342 |
+
If you find our work helpful, please cite us:
|
343 |
+
|
344 |
+
```bibtex
|
345 |
+
@article{liu2024rdt,
|
346 |
+
title={RDT-1B: a Diffusion Foundation Model for Bimanual Manipulation},
|
347 |
+
author={Liu, Songming and Wu, Lingxuan and Li, Bangguo and Tan, Hengkai and Chen, Huayu and Wang, Zhengyi and Xu, Ke and Su, Hang and Zhu, Jun},
|
348 |
+
journal={arXiv preprint arXiv:2410.07864},
|
349 |
+
year={2024}
|
350 |
+
}
|
351 |
+
```
|
352 |
+
|
353 |
+
Thank you!
|
354 |
+
|
355 |
+
## License
|
356 |
+
|
357 |
+
All the code, model weights, and data are licensed under [MIT license](./LICENSE).
|
Untitled.ipynb
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"id": "71e6c6b4-1e9b-4abb-b36e-80f7fa919c93",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import h5py\n",
|
11 |
+
"import os\n",
|
12 |
+
"import fnmatch"
|
13 |
+
]
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"cell_type": "code",
|
17 |
+
"execution_count": 5,
|
18 |
+
"id": "6bc2e70e-970c-4874-8ce2-5950acbf74d6",
|
19 |
+
"metadata": {},
|
20 |
+
"outputs": [],
|
21 |
+
"source": [
|
22 |
+
"dataset_name = 'singlevla_benchmark'\n",
|
23 |
+
"HDF5_DIR = f\"/home/shared/{dataset_name}/\"\n",
|
24 |
+
"DATASET_NAME = dataset_name\n",
|
25 |
+
"\n",
|
26 |
+
"file_paths = []\n",
|
27 |
+
"for root, _, files in os.walk(HDF5_DIR):\n",
|
28 |
+
" for filename in fnmatch.filter(files, '*.hdf5'):\n",
|
29 |
+
" file_path = os.path.join(root, filename)\n",
|
30 |
+
" file_paths.append(file_path)"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 7,
|
36 |
+
"id": "a43097df-bd96-4300-9473-748ce19406c4",
|
37 |
+
"metadata": {},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"f = h5py.File(file_paths[0], 'r')"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 9,
|
46 |
+
"id": "3bb5763f-4156-4dba-ac19-66361837afbd",
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [
|
49 |
+
{
|
50 |
+
"data": {
|
51 |
+
"text/plain": [
|
52 |
+
"<KeysViewHDF5 ['ee_pos', 'joint_pos', 'leftview_image', 'rightview_image']>"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
"execution_count": 9,
|
56 |
+
"metadata": {},
|
57 |
+
"output_type": "execute_result"
|
58 |
+
}
|
59 |
+
],
|
60 |
+
"source": [
|
61 |
+
"f['observation'].keys()"
|
62 |
+
]
|
63 |
+
}
|
64 |
+
],
|
65 |
+
"metadata": {
|
66 |
+
"kernelspec": {
|
67 |
+
"display_name": "Python 3 (ipykernel)",
|
68 |
+
"language": "python",
|
69 |
+
"name": "python3"
|
70 |
+
},
|
71 |
+
"language_info": {
|
72 |
+
"codemirror_mode": {
|
73 |
+
"name": "ipython",
|
74 |
+
"version": 3
|
75 |
+
},
|
76 |
+
"file_extension": ".py",
|
77 |
+
"mimetype": "text/x-python",
|
78 |
+
"name": "python",
|
79 |
+
"nbconvert_exporter": "python",
|
80 |
+
"pygments_lexer": "ipython3",
|
81 |
+
"version": "3.10.15"
|
82 |
+
}
|
83 |
+
},
|
84 |
+
"nbformat": 4,
|
85 |
+
"nbformat_minor": 5
|
86 |
+
}
|
data/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Ignore data files
|
2 |
+
datasets
|
data/compute_dataset_stat.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file will compute the min, max, mean, and standard deviation of each datasets
|
3 |
+
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
# from multiprocessing import Pool, Manager
|
10 |
+
|
11 |
+
import tensorflow as tf
|
12 |
+
import numpy as np
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from data.vla_dataset import VLADataset
|
16 |
+
from data.hdf5_vla_dataset import HDF5VLADataset
|
17 |
+
from data.preprocess import generate_json_state
|
18 |
+
|
19 |
+
|
20 |
+
# Process each dataset to get the statistics
|
21 |
+
@tf.autograph.experimental.do_not_convert
|
22 |
+
def process_dataset(name_dataset_pair):
|
23 |
+
# print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
|
24 |
+
dataset_iter = name_dataset_pair[1]
|
25 |
+
|
26 |
+
MAX_EPISODES = 100000
|
27 |
+
EPS = 1e-8
|
28 |
+
# For debugging
|
29 |
+
# MAX_EPISODES = 10
|
30 |
+
episode_cnt = 0
|
31 |
+
state_sum = 0
|
32 |
+
state_sum_sq = 0
|
33 |
+
z_state_sum = 0
|
34 |
+
z_state_sum_sq = 0
|
35 |
+
state_cnt = 0
|
36 |
+
nz_state_cnt = None
|
37 |
+
state_max = None
|
38 |
+
state_min = None
|
39 |
+
for episode in dataset_iter:
|
40 |
+
episode_cnt += 1
|
41 |
+
if episode_cnt % 1000 == 0:
|
42 |
+
print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}")
|
43 |
+
if episode_cnt > MAX_EPISODES:
|
44 |
+
break
|
45 |
+
episode_dict = episode['episode_dict']
|
46 |
+
dataset_name = episode['dataset_name']
|
47 |
+
|
48 |
+
res_tup = generate_json_state(
|
49 |
+
episode_dict, dataset_name
|
50 |
+
)
|
51 |
+
states = res_tup[1]
|
52 |
+
|
53 |
+
# Convert to numpy
|
54 |
+
states = states.numpy()
|
55 |
+
|
56 |
+
# Zero the values that are close to zero
|
57 |
+
z_states = states.copy()
|
58 |
+
z_states[np.abs(states) <= EPS] = 0
|
59 |
+
# Compute the non-zero count
|
60 |
+
if nz_state_cnt is None:
|
61 |
+
nz_state_cnt = np.zeros(states.shape[1])
|
62 |
+
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
63 |
+
|
64 |
+
# Update statistics
|
65 |
+
state_sum += np.sum(states, axis=0)
|
66 |
+
state_sum_sq += np.sum(states**2, axis=0)
|
67 |
+
z_state_sum += np.sum(z_states, axis=0)
|
68 |
+
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
69 |
+
state_cnt += states.shape[0]
|
70 |
+
if state_max is None:
|
71 |
+
state_max = np.max(states, axis=0)
|
72 |
+
state_min = np.min(states, axis=0)
|
73 |
+
else:
|
74 |
+
state_max = np.maximum(state_max, np.max(states, axis=0))
|
75 |
+
state_min = np.minimum(state_min, np.min(states, axis=0))
|
76 |
+
|
77 |
+
# Add one to avoid division by zero
|
78 |
+
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
79 |
+
|
80 |
+
result = {
|
81 |
+
"dataset_name": name_dataset_pair[0],
|
82 |
+
"state_mean": (state_sum / state_cnt).tolist(),
|
83 |
+
"state_std": np.sqrt(
|
84 |
+
np.maximum(
|
85 |
+
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
86 |
+
np.zeros_like(state_sum_sq)
|
87 |
+
)
|
88 |
+
).tolist(),
|
89 |
+
"state_min": state_min.tolist(),
|
90 |
+
"state_max": state_max.tolist(),
|
91 |
+
}
|
92 |
+
|
93 |
+
return result
|
94 |
+
|
95 |
+
|
96 |
+
def process_hdf5_dataset(vla_dataset):
|
97 |
+
EPS = 1e-8
|
98 |
+
episode_cnt = 0
|
99 |
+
state_sum = 0
|
100 |
+
state_sum_sq = 0
|
101 |
+
z_state_sum = 0
|
102 |
+
z_state_sum_sq = 0
|
103 |
+
state_cnt = 0
|
104 |
+
nz_state_cnt = None
|
105 |
+
state_max = None
|
106 |
+
state_min = None
|
107 |
+
for i in tqdm(range(len(vla_dataset))):
|
108 |
+
episode = vla_dataset.get_item(i, state_only=True)
|
109 |
+
episode_cnt += 1
|
110 |
+
|
111 |
+
states = episode['state']
|
112 |
+
|
113 |
+
# Zero the values that are close to zero
|
114 |
+
z_states = states.copy()
|
115 |
+
z_states[np.abs(states) <= EPS] = 0
|
116 |
+
# Compute the non-zero count
|
117 |
+
if nz_state_cnt is None:
|
118 |
+
nz_state_cnt = np.zeros(states.shape[1])
|
119 |
+
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
120 |
+
|
121 |
+
# Update statistics
|
122 |
+
state_sum += np.sum(states, axis=0)
|
123 |
+
state_sum_sq += np.sum(states**2, axis=0)
|
124 |
+
z_state_sum += np.sum(z_states, axis=0)
|
125 |
+
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
126 |
+
state_cnt += states.shape[0]
|
127 |
+
if state_max is None:
|
128 |
+
state_max = np.max(states, axis=0)
|
129 |
+
state_min = np.min(states, axis=0)
|
130 |
+
else:
|
131 |
+
state_max = np.maximum(state_max, np.max(states, axis=0))
|
132 |
+
state_min = np.minimum(state_min, np.min(states, axis=0))
|
133 |
+
|
134 |
+
# Add one to avoid division by zero
|
135 |
+
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
136 |
+
|
137 |
+
result = {
|
138 |
+
"dataset_name": vla_dataset.get_dataset_name(),
|
139 |
+
"state_mean": (state_sum / state_cnt).tolist(),
|
140 |
+
"state_std": np.sqrt(
|
141 |
+
np.maximum(
|
142 |
+
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
143 |
+
np.zeros_like(state_sum_sq)
|
144 |
+
)
|
145 |
+
).tolist(),
|
146 |
+
"state_min": state_min.tolist(),
|
147 |
+
"state_max": state_max.tolist(),
|
148 |
+
}
|
149 |
+
|
150 |
+
return result
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
parser = argparse.ArgumentParser()
|
155 |
+
# Multiprocessing currently with bugs
|
156 |
+
# parser.add_argument('--n_workers', type=int, default=1,
|
157 |
+
# help="Number of parallel workers.")
|
158 |
+
parser.add_argument('--dataset_type', type=str,
|
159 |
+
default="pretrain",
|
160 |
+
help="Whether to load the pretrain dataset or finetune dataset.")
|
161 |
+
parser.add_argument('--save_path', type=str,
|
162 |
+
default="configs/dataset_stat.json",
|
163 |
+
help="JSON file path to save the dataset statistics.")
|
164 |
+
parser.add_argument('--skip_exist', action='store_true',
|
165 |
+
help="Whether to skip the existing dataset statistics.")
|
166 |
+
parser.add_argument('--hdf5_dataset', action='store_true',
|
167 |
+
help="Whether to load the dataset from the HDF5 files.")
|
168 |
+
args = parser.parse_args()
|
169 |
+
|
170 |
+
if args.hdf5_dataset:
|
171 |
+
vla_dataset = HDF5VLADataset()
|
172 |
+
dataset_name = vla_dataset.get_dataset_name()
|
173 |
+
|
174 |
+
try:
|
175 |
+
with open(args.save_path, 'r') as f:
|
176 |
+
results = json.load(f)
|
177 |
+
except FileNotFoundError:
|
178 |
+
results = {}
|
179 |
+
if args.skip_exist and dataset_name in results:
|
180 |
+
print(f"Skipping existed {dataset_name} dataset statistics")
|
181 |
+
else:
|
182 |
+
print(f"Processing {dataset_name} dataset")
|
183 |
+
result = process_hdf5_dataset(vla_dataset)
|
184 |
+
results[result["dataset_name"]] = result
|
185 |
+
with open(args.save_path, 'w') as f:
|
186 |
+
json.dump(results, f, indent=4)
|
187 |
+
print("All datasets have been processed.")
|
188 |
+
os._exit(0)
|
189 |
+
|
190 |
+
vla_dataset = VLADataset(
|
191 |
+
seed=0, dataset_type=args.dataset_type, repeat=False)
|
192 |
+
name_dataset_pairs = vla_dataset.name2dataset.items()
|
193 |
+
# num_workers = args.n_workers
|
194 |
+
|
195 |
+
for name_dataset_pair in tqdm(name_dataset_pairs):
|
196 |
+
try:
|
197 |
+
with open(args.save_path, 'r') as f:
|
198 |
+
results = json.load(f)
|
199 |
+
except FileNotFoundError:
|
200 |
+
results = {}
|
201 |
+
|
202 |
+
if args.skip_exist and name_dataset_pair[0] in results:
|
203 |
+
print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
|
204 |
+
continue
|
205 |
+
print(f"Processing {name_dataset_pair[0]} dataset")
|
206 |
+
|
207 |
+
result = process_dataset(name_dataset_pair)
|
208 |
+
|
209 |
+
results[result["dataset_name"]] = result
|
210 |
+
|
211 |
+
# Save the results in the json file after each dataset (for resume)
|
212 |
+
with open(args.save_path, 'w') as f:
|
213 |
+
json.dump(results, f, indent=4)
|
214 |
+
|
215 |
+
print("All datasets have been processed.")
|
216 |
+
|
217 |
+
# with Manager() as manager:
|
218 |
+
# # Create shared dictionary and lock through the manager, accessible by all processes
|
219 |
+
# progress = manager.dict(processed=0, results={})
|
220 |
+
# progress_lock = manager.Lock()
|
221 |
+
|
222 |
+
# # Callback function to update progress
|
223 |
+
# def update_progress(result):
|
224 |
+
# with progress_lock:
|
225 |
+
# progress['processed'] += 1
|
226 |
+
# print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
|
227 |
+
# # Append the result to the shared dictionary
|
228 |
+
# progress['results'][result["dataset_name"]] = result
|
229 |
+
|
230 |
+
# with Pool(num_workers) as p:
|
231 |
+
# for name_dataset_pair in name_dataset_pairs:
|
232 |
+
# p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)
|
233 |
+
|
234 |
+
# # Close the pool and wait for the work to finish
|
235 |
+
# p.close()
|
236 |
+
# p.join()
|
237 |
+
|
238 |
+
# # Save the results in the json file
|
239 |
+
# with open(args.save_path, 'w') as f:
|
240 |
+
# json.dump(progress['results'], f, indent=4)
|
data/compute_dataset_stat_hdf5.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
This file will compute the min, max, mean, and standard deviation of each datasets
|
3 |
+
in `pretrain_datasets.json` or `pretrain_datasets.json`.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
# from data.hdf5_vla_dataset import TabletopHDF5VLADataset as HDF5VLADataset
|
13 |
+
from data.hdf5_vla_dataset import AnubisHDF5VLADataset as HDF5VLADataset
|
14 |
+
|
15 |
+
|
16 |
+
def process_hdf5_dataset(vla_dataset):
|
17 |
+
EPS = 1e-8
|
18 |
+
episode_cnt = 0
|
19 |
+
state_sum = 0
|
20 |
+
state_sum_sq = 0
|
21 |
+
z_state_sum = 0
|
22 |
+
z_state_sum_sq = 0
|
23 |
+
state_cnt = 0
|
24 |
+
nz_state_cnt = None
|
25 |
+
state_max = None
|
26 |
+
state_min = None
|
27 |
+
for i in tqdm(range(len(vla_dataset))):
|
28 |
+
# print(i)
|
29 |
+
episode = vla_dataset.get_item(i, state_only=True)
|
30 |
+
episode_cnt += 1
|
31 |
+
|
32 |
+
states = episode['state']
|
33 |
+
|
34 |
+
# Zero the values that are close to zero
|
35 |
+
z_states = states.copy()
|
36 |
+
z_states[np.abs(states) <= EPS] = 0
|
37 |
+
# Compute the non-zero count
|
38 |
+
if nz_state_cnt is None:
|
39 |
+
nz_state_cnt = np.zeros(states.shape[1])
|
40 |
+
nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0)
|
41 |
+
|
42 |
+
# Update statistics
|
43 |
+
state_sum += np.sum(states, axis=0)
|
44 |
+
state_sum_sq += np.sum(states**2, axis=0)
|
45 |
+
z_state_sum += np.sum(z_states, axis=0)
|
46 |
+
z_state_sum_sq += np.sum(z_states**2, axis=0)
|
47 |
+
state_cnt += states.shape[0]
|
48 |
+
if state_max is None:
|
49 |
+
state_max = np.max(states, axis=0)
|
50 |
+
state_min = np.min(states, axis=0)
|
51 |
+
else:
|
52 |
+
state_max = np.maximum(state_max, np.max(states, axis=0))
|
53 |
+
state_min = np.minimum(state_min, np.min(states, axis=0))
|
54 |
+
|
55 |
+
# Add one to avoid division by zero
|
56 |
+
nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt))
|
57 |
+
|
58 |
+
result = {
|
59 |
+
"dataset_name": vla_dataset.get_dataset_name(),
|
60 |
+
"state_mean": (state_sum / state_cnt).tolist(),
|
61 |
+
"state_std": np.sqrt(
|
62 |
+
np.maximum(
|
63 |
+
(z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt),
|
64 |
+
np.zeros_like(state_sum_sq)
|
65 |
+
)
|
66 |
+
).tolist(),
|
67 |
+
"state_min": state_min.tolist(),
|
68 |
+
"state_max": state_max.tolist(),
|
69 |
+
}
|
70 |
+
|
71 |
+
return result
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
parser = argparse.ArgumentParser()
|
76 |
+
parser.add_argument('--save_path', type=str,
|
77 |
+
default="configs/dataset_stat.json",
|
78 |
+
help="JSON file path to save the dataset statistics.")
|
79 |
+
parser.add_argument('--skip_exist', action='store_true',
|
80 |
+
help="Whether to skip the existing dataset statistics.")
|
81 |
+
parser.add_argument('--dataset', type=str)
|
82 |
+
args = parser.parse_args()
|
83 |
+
|
84 |
+
vla_dataset = HDF5VLADataset(args.dataset)
|
85 |
+
dataset_name = vla_dataset.get_dataset_name()
|
86 |
+
|
87 |
+
try:
|
88 |
+
with open(args.save_path, 'r') as f:
|
89 |
+
results = json.load(f)
|
90 |
+
except FileNotFoundError:
|
91 |
+
results = {}
|
92 |
+
if args.skip_exist and dataset_name in results:
|
93 |
+
print(f"Skipping existed {dataset_name} dataset statistics")
|
94 |
+
else:
|
95 |
+
print(f"Processing {dataset_name} dataset")
|
96 |
+
result = process_hdf5_dataset(vla_dataset)
|
97 |
+
results[result["dataset_name"]] = result
|
98 |
+
with open(args.save_path, 'w') as f:
|
99 |
+
json.dump(results, f, indent=4)
|
100 |
+
print("All datasets have been processed.")
|
data/episode_transform.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import tensorflow as tf
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
from data.preprocess import generate_json_state
|
6 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
7 |
+
|
8 |
+
|
9 |
+
# Read the config
|
10 |
+
with open('configs/base.yaml', 'r') as file:
|
11 |
+
config = yaml.safe_load(file)
|
12 |
+
# Load some constants from the config
|
13 |
+
IMG_HISTORY_SIZE = config['common']['img_history_size']
|
14 |
+
if IMG_HISTORY_SIZE < 1:
|
15 |
+
raise ValueError("Config `img_history_size` must be at least 1.")
|
16 |
+
ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
|
17 |
+
if ACTION_CHUNK_SIZE < 1:
|
18 |
+
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
19 |
+
|
20 |
+
|
21 |
+
@tf.function
|
22 |
+
def process_episode(epsd: dict, dataset_name: str,
|
23 |
+
image_keys: list, image_mask: list) -> dict:
|
24 |
+
"""
|
25 |
+
Process an episode to extract the frames and the json content.
|
26 |
+
"""
|
27 |
+
# Frames of each camera
|
28 |
+
# Ugly code due to tf's poor compatibility
|
29 |
+
frames_0 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
30 |
+
frames_1 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
31 |
+
frames_2 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
32 |
+
frames_3 = tf.TensorArray(dtype=tf.uint8, size=0, dynamic_size=True)
|
33 |
+
# Traverse the episode to collect...
|
34 |
+
for step in iter(epsd['steps']):
|
35 |
+
# Parse the image
|
36 |
+
frames_0 = frames_0.write(frames_0.size(),
|
37 |
+
tf.cond(
|
38 |
+
tf.equal(image_mask[0], 1),
|
39 |
+
lambda: step['observation'][image_keys[0]],
|
40 |
+
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
|
41 |
+
))
|
42 |
+
# Very ugly code due to tf's poor compatibility
|
43 |
+
frames_1 = frames_1.write(frames_1.size(),
|
44 |
+
tf.cond(
|
45 |
+
tf.equal(image_mask[1], 1),
|
46 |
+
lambda: step['observation'][image_keys[1]],
|
47 |
+
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
|
48 |
+
))
|
49 |
+
# print(image_mask)
|
50 |
+
frames_2 = frames_2.write(frames_2.size(),
|
51 |
+
tf.cond(
|
52 |
+
tf.equal(image_mask[2], 1),
|
53 |
+
lambda: step['observation'][image_keys[2]],
|
54 |
+
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
|
55 |
+
))
|
56 |
+
frames_3 = frames_3.write(frames_3.size(),
|
57 |
+
tf.cond(
|
58 |
+
tf.equal(image_mask[3], 1),
|
59 |
+
lambda: step['observation'][image_keys[3]],
|
60 |
+
lambda: tf.zeros([0, 0, 0], dtype=tf.uint8)
|
61 |
+
))
|
62 |
+
|
63 |
+
|
64 |
+
# Calculate the past_frames_0 for each step
|
65 |
+
# Each step has a window of previous frames with size IMG_HISTORY_SIZE
|
66 |
+
# Use the first state to pad the frames
|
67 |
+
# past_frames_0 will have shape (num_steps, IMG_HISTORY_SIZE, height, width, channels)
|
68 |
+
frames_0 = frames_0.stack()
|
69 |
+
first_frame = tf.expand_dims(frames_0[0], axis=0)
|
70 |
+
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
|
71 |
+
padded_frames_0 = tf.concat([first_frame, frames_0], axis=0)
|
72 |
+
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_0)[0] + IMG_HISTORY_SIZE)
|
73 |
+
past_frames_0 = tf.map_fn(
|
74 |
+
lambda i: padded_frames_0[i - IMG_HISTORY_SIZE:i],
|
75 |
+
indices,
|
76 |
+
dtype=tf.uint8
|
77 |
+
)
|
78 |
+
frames_0_time_mask = tf.ones([tf.shape(frames_0)[0]], dtype=tf.bool)
|
79 |
+
padded_frames_0_time_mask = tf.pad(frames_0_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
80 |
+
past_frames_0_time_mask = tf.map_fn(
|
81 |
+
lambda i: padded_frames_0_time_mask[i - IMG_HISTORY_SIZE:i],
|
82 |
+
indices,
|
83 |
+
dtype=tf.bool
|
84 |
+
)
|
85 |
+
|
86 |
+
# For past_frames_1
|
87 |
+
frames_1 = frames_1.stack()
|
88 |
+
first_frame = tf.expand_dims(frames_1[0], axis=0)
|
89 |
+
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
|
90 |
+
padded_frames_1 = tf.concat([first_frame, frames_1], axis=0)
|
91 |
+
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_1)[0] + IMG_HISTORY_SIZE)
|
92 |
+
past_frames_1 = tf.map_fn(
|
93 |
+
lambda i: padded_frames_1[i - IMG_HISTORY_SIZE:i],
|
94 |
+
indices,
|
95 |
+
dtype=tf.uint8
|
96 |
+
)
|
97 |
+
frames_1_time_mask = tf.ones([tf.shape(frames_1)[0]], dtype=tf.bool)
|
98 |
+
padded_frames_1_time_mask = tf.pad(frames_1_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
99 |
+
past_frames_1_time_mask = tf.map_fn(
|
100 |
+
lambda i: padded_frames_1_time_mask[i - IMG_HISTORY_SIZE:i],
|
101 |
+
indices,
|
102 |
+
dtype=tf.bool
|
103 |
+
)
|
104 |
+
|
105 |
+
# For past_frames_2
|
106 |
+
frames_2 = frames_2.stack()
|
107 |
+
first_frame = tf.expand_dims(frames_2[0], axis=0)
|
108 |
+
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
|
109 |
+
padded_frames_2 = tf.concat([first_frame, frames_2], axis=0)
|
110 |
+
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_2)[0] + IMG_HISTORY_SIZE)
|
111 |
+
past_frames_2 = tf.map_fn(
|
112 |
+
lambda i: padded_frames_2[i - IMG_HISTORY_SIZE:i],
|
113 |
+
indices,
|
114 |
+
dtype=tf.uint8
|
115 |
+
)
|
116 |
+
frames_2_time_mask = tf.ones([tf.shape(frames_2)[0]], dtype=tf.bool)
|
117 |
+
padded_frames_2_time_mask = tf.pad(frames_2_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
118 |
+
past_frames_2_time_mask = tf.map_fn(
|
119 |
+
lambda i: padded_frames_2_time_mask[i - IMG_HISTORY_SIZE:i],
|
120 |
+
indices,
|
121 |
+
dtype=tf.bool
|
122 |
+
)
|
123 |
+
|
124 |
+
# For past_frames_3
|
125 |
+
frames_3 = frames_3.stack()
|
126 |
+
first_frame = tf.expand_dims(frames_3[0], axis=0)
|
127 |
+
first_frame = tf.repeat(first_frame, IMG_HISTORY_SIZE-1, axis=0)
|
128 |
+
padded_frames_3 = tf.concat([first_frame, frames_3], axis=0)
|
129 |
+
indices = tf.range(IMG_HISTORY_SIZE, tf.shape(frames_3)[0] + IMG_HISTORY_SIZE)
|
130 |
+
past_frames_3 = tf.map_fn(
|
131 |
+
lambda i: padded_frames_3[i - IMG_HISTORY_SIZE:i],
|
132 |
+
indices,
|
133 |
+
dtype=tf.uint8
|
134 |
+
)
|
135 |
+
frames_3_time_mask = tf.ones([tf.shape(frames_3)[0]], dtype=tf.bool)
|
136 |
+
padded_frames_3_time_mask = tf.pad(frames_3_time_mask, [[IMG_HISTORY_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
137 |
+
past_frames_3_time_mask = tf.map_fn(
|
138 |
+
lambda i: padded_frames_3_time_mask[i - IMG_HISTORY_SIZE:i],
|
139 |
+
indices,
|
140 |
+
dtype=tf.bool
|
141 |
+
)
|
142 |
+
|
143 |
+
# Creat the ids for each step
|
144 |
+
step_id = tf.range(0, tf.shape(frames_0)[0])
|
145 |
+
|
146 |
+
return {
|
147 |
+
'dataset_name': dataset_name,
|
148 |
+
'episode_dict': epsd,
|
149 |
+
'step_id': step_id,
|
150 |
+
'past_frames_0': past_frames_0,
|
151 |
+
'past_frames_0_time_mask': past_frames_0_time_mask,
|
152 |
+
'past_frames_1': past_frames_1,
|
153 |
+
'past_frames_1_time_mask': past_frames_1_time_mask,
|
154 |
+
'past_frames_2': past_frames_2,
|
155 |
+
'past_frames_2_time_mask': past_frames_2_time_mask,
|
156 |
+
'past_frames_3': past_frames_3,
|
157 |
+
'past_frames_3_time_mask': past_frames_3_time_mask,
|
158 |
+
}
|
159 |
+
|
160 |
+
|
161 |
+
@tf.function
|
162 |
+
def bgr_to_rgb(epsd: dict):
|
163 |
+
"""
|
164 |
+
Convert BGR images to RGB images.
|
165 |
+
"""
|
166 |
+
past_frames_0 = epsd['past_frames_0']
|
167 |
+
past_frames_0 = tf.cond(
|
168 |
+
tf.equal(tf.shape(past_frames_0)[-1], 3),
|
169 |
+
lambda: tf.stack([
|
170 |
+
past_frames_0[..., 2],
|
171 |
+
past_frames_0[..., 1],
|
172 |
+
past_frames_0[..., 0]
|
173 |
+
], axis=-1),
|
174 |
+
lambda: past_frames_0
|
175 |
+
)
|
176 |
+
|
177 |
+
past_frames_1 = epsd['past_frames_1']
|
178 |
+
past_frames_1 = tf.cond(
|
179 |
+
tf.equal(tf.shape(past_frames_1)[-1], 3),
|
180 |
+
lambda: tf.stack([
|
181 |
+
past_frames_1[..., 2],
|
182 |
+
past_frames_1[..., 1],
|
183 |
+
past_frames_1[..., 0]
|
184 |
+
], axis=-1),
|
185 |
+
lambda: past_frames_1
|
186 |
+
)
|
187 |
+
|
188 |
+
past_frames_2 = epsd['past_frames_2']
|
189 |
+
past_frames_2 = tf.cond(
|
190 |
+
tf.equal(tf.shape(past_frames_2)[-1], 3),
|
191 |
+
lambda: tf.stack([
|
192 |
+
past_frames_2[..., 2],
|
193 |
+
past_frames_2[..., 1],
|
194 |
+
past_frames_2[..., 0]
|
195 |
+
], axis=-1),
|
196 |
+
lambda: past_frames_2
|
197 |
+
)
|
198 |
+
|
199 |
+
past_frames_3 = epsd['past_frames_3']
|
200 |
+
past_frames_3 = tf.cond(
|
201 |
+
tf.equal(tf.shape(past_frames_3)[-1], 3),
|
202 |
+
lambda: tf.stack([
|
203 |
+
past_frames_3[..., 2],
|
204 |
+
past_frames_3[..., 1],
|
205 |
+
past_frames_3[..., 0]
|
206 |
+
], axis=-1),
|
207 |
+
lambda: past_frames_3
|
208 |
+
)
|
209 |
+
|
210 |
+
return {
|
211 |
+
'dataset_name': epsd['dataset_name'],
|
212 |
+
'episode_dict': epsd['episode_dict'],
|
213 |
+
'step_id': epsd['step_id'],
|
214 |
+
'past_frames_0': past_frames_0,
|
215 |
+
'past_frames_0_time_mask': epsd['past_frames_0_time_mask'],
|
216 |
+
'past_frames_1': past_frames_1,
|
217 |
+
'past_frames_1_time_mask': epsd['past_frames_1_time_mask'],
|
218 |
+
'past_frames_2': past_frames_2,
|
219 |
+
'past_frames_2_time_mask': epsd['past_frames_2_time_mask'],
|
220 |
+
'past_frames_3': past_frames_3,
|
221 |
+
'past_frames_3_time_mask': epsd['past_frames_3_time_mask'],
|
222 |
+
}
|
223 |
+
|
224 |
+
|
225 |
+
def flatten_episode(episode: dict) -> tf.data.Dataset:
|
226 |
+
"""
|
227 |
+
Flatten the episode to a list of steps.
|
228 |
+
"""
|
229 |
+
episode_dict = episode['episode_dict']
|
230 |
+
dataset_name = episode['dataset_name']
|
231 |
+
|
232 |
+
json_content, states, masks = generate_json_state(
|
233 |
+
episode_dict, dataset_name
|
234 |
+
)
|
235 |
+
|
236 |
+
# Calculate the past_states for each step
|
237 |
+
# Each step has a window of previous states with size ACTION_CHUNK_SIZE
|
238 |
+
# Use the first state to pad the states
|
239 |
+
# past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
240 |
+
first_state = tf.expand_dims(states[0], axis=0)
|
241 |
+
first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE-1, axis=0)
|
242 |
+
padded_states = tf.concat([first_state, states], axis=0)
|
243 |
+
indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
|
244 |
+
past_states = tf.map_fn(
|
245 |
+
lambda i: padded_states[i - ACTION_CHUNK_SIZE:i],
|
246 |
+
indices,
|
247 |
+
dtype=tf.float32
|
248 |
+
)
|
249 |
+
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
250 |
+
padded_states_time_mask = tf.pad(states_time_mask, [[ACTION_CHUNK_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
251 |
+
past_states_time_mask = tf.map_fn(
|
252 |
+
lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
|
253 |
+
indices,
|
254 |
+
dtype=tf.bool
|
255 |
+
)
|
256 |
+
|
257 |
+
# Calculate the future_states for each step
|
258 |
+
# Each step has a window of future states with size ACTION_CHUNK_SIZE
|
259 |
+
# Use the last state to pad the states
|
260 |
+
# future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
261 |
+
last_state = tf.expand_dims(states[-1], axis=0)
|
262 |
+
last_state = tf.repeat(last_state, ACTION_CHUNK_SIZE, axis=0)
|
263 |
+
padded_states = tf.concat([states, last_state], axis=0)
|
264 |
+
indices = tf.range(1, tf.shape(states)[0] + 1)
|
265 |
+
future_states = tf.map_fn(
|
266 |
+
lambda i: padded_states[i:i + ACTION_CHUNK_SIZE],
|
267 |
+
indices,
|
268 |
+
dtype=tf.float32
|
269 |
+
)
|
270 |
+
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
271 |
+
padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
|
272 |
+
future_states_time_mask = tf.map_fn(
|
273 |
+
lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
|
274 |
+
indices,
|
275 |
+
dtype=tf.bool
|
276 |
+
)
|
277 |
+
|
278 |
+
# Calculate the mean and std for state
|
279 |
+
state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
|
280 |
+
state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
|
281 |
+
state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
|
282 |
+
state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
|
283 |
+
|
284 |
+
state_norm = tf.math.reduce_mean(
|
285 |
+
tf.math.square(states), axis=0, keepdims=True)
|
286 |
+
state_norm = tf.math.sqrt(state_norm)
|
287 |
+
state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
|
288 |
+
|
289 |
+
# Create a list of steps
|
290 |
+
step_data = []
|
291 |
+
for i in range(tf.shape(states)[0]):
|
292 |
+
step_data.append({
|
293 |
+
'step_id': episode['step_id'][i],
|
294 |
+
'json_content': json_content,
|
295 |
+
'state_chunk': past_states[i],
|
296 |
+
'state_chunk_time_mask': past_states_time_mask[i],
|
297 |
+
'action_chunk': future_states[i],
|
298 |
+
'action_chunk_time_mask': future_states_time_mask[i],
|
299 |
+
'state_vec_mask': masks[i],
|
300 |
+
'past_frames_0': episode['past_frames_0'][i],
|
301 |
+
'past_frames_0_time_mask': episode['past_frames_0_time_mask'][i],
|
302 |
+
'past_frames_1': episode['past_frames_1'][i],
|
303 |
+
'past_frames_1_time_mask': episode['past_frames_1_time_mask'][i],
|
304 |
+
'past_frames_2': episode['past_frames_2'][i],
|
305 |
+
'past_frames_2_time_mask': episode['past_frames_2_time_mask'][i],
|
306 |
+
'past_frames_3': episode['past_frames_3'][i],
|
307 |
+
'past_frames_3_time_mask': episode['past_frames_3_time_mask'][i],
|
308 |
+
'state_std': state_std[i],
|
309 |
+
'state_mean': state_mean[i],
|
310 |
+
'state_norm': state_norm[i],
|
311 |
+
})
|
312 |
+
|
313 |
+
return step_data
|
314 |
+
|
315 |
+
|
316 |
+
def flatten_episode_agilex(episode: dict) -> tf.data.Dataset:
|
317 |
+
"""
|
318 |
+
Flatten the episode to a list of steps.
|
319 |
+
"""
|
320 |
+
episode_dict = episode['episode_dict']
|
321 |
+
dataset_name = episode['dataset_name']
|
322 |
+
|
323 |
+
json_content, states, masks, acts = generate_json_state(
|
324 |
+
episode_dict, dataset_name
|
325 |
+
)
|
326 |
+
|
327 |
+
# Calculate the past_states for each step
|
328 |
+
# Each step has a window of previous states with size ACTION_CHUNK_SIZE
|
329 |
+
# Use the first state to pad the states
|
330 |
+
# past_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
331 |
+
first_state = tf.expand_dims(states[0], axis=0)
|
332 |
+
first_state = tf.repeat(first_state, ACTION_CHUNK_SIZE-1, axis=0)
|
333 |
+
padded_states = tf.concat([first_state, states], axis=0)
|
334 |
+
indices = tf.range(ACTION_CHUNK_SIZE, tf.shape(states)[0] + ACTION_CHUNK_SIZE)
|
335 |
+
past_states = tf.map_fn(
|
336 |
+
lambda i: padded_states[i - ACTION_CHUNK_SIZE:i],
|
337 |
+
indices,
|
338 |
+
dtype=tf.float32
|
339 |
+
)
|
340 |
+
states_time_mask = tf.ones([tf.shape(states)[0]], dtype=tf.bool)
|
341 |
+
padded_states_time_mask = tf.pad(states_time_mask, [[ACTION_CHUNK_SIZE-1, 0]], "CONSTANT", constant_values=False)
|
342 |
+
past_states_time_mask = tf.map_fn(
|
343 |
+
lambda i: padded_states_time_mask[i - ACTION_CHUNK_SIZE:i],
|
344 |
+
indices,
|
345 |
+
dtype=tf.bool
|
346 |
+
)
|
347 |
+
|
348 |
+
# NOTE bg the future states shall be actions
|
349 |
+
# Calculate the future_states for each step
|
350 |
+
# Each step has a window of future states with size ACTION_CHUNK_SIZE
|
351 |
+
# Use the last action to pad the states
|
352 |
+
# future_states will have shape (num_steps, ACTION_CHUNK_SIZE, state_dim)
|
353 |
+
last_act = tf.expand_dims(acts[-1], axis=0)
|
354 |
+
last_act = tf.repeat(last_act, ACTION_CHUNK_SIZE, axis=0)
|
355 |
+
padded_states = tf.concat([acts, last_act], axis=0)
|
356 |
+
# indices = tf.range(1, tf.shape(states)[0] + 1)
|
357 |
+
indices = tf.range(0, tf.shape(acts)[0]) # NOTE time 0 action = time 1 state
|
358 |
+
future_states = tf.map_fn(
|
359 |
+
lambda i: padded_states[i:i + ACTION_CHUNK_SIZE],
|
360 |
+
indices,
|
361 |
+
dtype=tf.float32
|
362 |
+
)
|
363 |
+
states_time_mask = tf.ones([tf.shape(acts)[0]], dtype=tf.bool)
|
364 |
+
padded_states_time_mask = tf.pad(states_time_mask, [[0, ACTION_CHUNK_SIZE]], "CONSTANT", constant_values=False)
|
365 |
+
future_states_time_mask = tf.map_fn(
|
366 |
+
lambda i: padded_states_time_mask[i:i + ACTION_CHUNK_SIZE],
|
367 |
+
indices,
|
368 |
+
dtype=tf.bool
|
369 |
+
)
|
370 |
+
|
371 |
+
# Calculate the std and mean for state
|
372 |
+
state_std = tf.math.reduce_std(states, axis=0, keepdims=True)
|
373 |
+
state_std = tf.repeat(state_std, tf.shape(states)[0], axis=0)
|
374 |
+
state_mean = tf.math.reduce_mean(states, axis=0, keepdims=True)
|
375 |
+
state_mean = tf.repeat(state_mean, tf.shape(states)[0], axis=0)
|
376 |
+
|
377 |
+
state_norm = tf.math.reduce_mean(
|
378 |
+
tf.math.square(acts), axis=0, keepdims=True)
|
379 |
+
state_norm = tf.math.sqrt(state_norm)
|
380 |
+
state_norm = tf.repeat(state_norm, tf.shape(states)[0], axis=0)
|
381 |
+
|
382 |
+
# Create a list of steps
|
383 |
+
step_data = []
|
384 |
+
for i in range(tf.shape(states)[0]):
|
385 |
+
step_data.append({
|
386 |
+
'step_id': episode['step_id'][i],
|
387 |
+
'json_content': json_content,
|
388 |
+
'state_chunk': past_states[i],
|
389 |
+
'state_chunk_time_mask': past_states_time_mask[i],
|
390 |
+
'action_chunk': future_states[i],
|
391 |
+
'action_chunk_time_mask': future_states_time_mask[i],
|
392 |
+
'state_vec_mask': masks[i],
|
393 |
+
'past_frames_0': episode['past_frames_0'][i],
|
394 |
+
'past_frames_0_time_mask': episode['past_frames_0_time_mask'][i],
|
395 |
+
'past_frames_1': episode['past_frames_1'][i],
|
396 |
+
'past_frames_1_time_mask': episode['past_frames_1_time_mask'][i],
|
397 |
+
'past_frames_2': episode['past_frames_2'][i],
|
398 |
+
'past_frames_2_time_mask': episode['past_frames_2_time_mask'][i],
|
399 |
+
'past_frames_3': episode['past_frames_3'][i],
|
400 |
+
'past_frames_3_time_mask': episode['past_frames_3_time_mask'][i],
|
401 |
+
'state_std': state_std[i],
|
402 |
+
'state_mean': state_mean[i],
|
403 |
+
'state_norm': state_norm[i],
|
404 |
+
})
|
405 |
+
|
406 |
+
return step_data
|
data/filelock.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fcntl
|
2 |
+
|
3 |
+
|
4 |
+
class FileLock:
|
5 |
+
"""
|
6 |
+
A file lock class.
|
7 |
+
"""
|
8 |
+
def __init__(self, filename):
|
9 |
+
self.filename = filename
|
10 |
+
self.handle = None
|
11 |
+
|
12 |
+
def acquire_read_lock(self):
|
13 |
+
self.handle = open(self.filename + '.lock', 'r')
|
14 |
+
fcntl.flock(self.handle, fcntl.LOCK_SH | fcntl.LOCK_NB)
|
15 |
+
|
16 |
+
def acquire_write_lock(self):
|
17 |
+
self.handle = open(self.filename + '.lock', 'w')
|
18 |
+
fcntl.flock(self.handle, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
19 |
+
|
20 |
+
def release_lock(self):
|
21 |
+
if self.handle is not None:
|
22 |
+
fcntl.flock(self.handle, fcntl.LOCK_UN)
|
23 |
+
self.handle.close()
|
24 |
+
self.handle = None
|
data/hdf5_maniskill_dataset.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import h5py
|
3 |
+
import yaml
|
4 |
+
import numpy as np
|
5 |
+
# Assuming STATE_VEC_IDX_MAPPING is a dictionary mapping state variable names to indices
|
6 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
7 |
+
import glob
|
8 |
+
from scipy.interpolate import interp1d
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
|
12 |
+
def interpolate_action_sequence(action_sequence, target_size):
|
13 |
+
"""
|
14 |
+
Extend the action sequece to `target_size` by linear interpolation.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
action_sequence (np.ndarray): original action sequence, shape (N, D).
|
18 |
+
target_size (int): target sequence length.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
extended_sequence (np.ndarray): extended action sequence, shape (target_size, D).
|
22 |
+
"""
|
23 |
+
N, D = action_sequence.shape
|
24 |
+
indices_old = np.arange(N)
|
25 |
+
indices_new = np.linspace(0, N - 1, target_size)
|
26 |
+
|
27 |
+
interp_func = interp1d(indices_old, action_sequence,
|
28 |
+
kind='linear', axis=0, assume_sorted=True)
|
29 |
+
action_sequence_new = interp_func(indices_new)
|
30 |
+
|
31 |
+
return action_sequence_new
|
32 |
+
|
33 |
+
|
34 |
+
class HDF5VLADataset:
|
35 |
+
"""
|
36 |
+
This class is used to sample episodes from the embodiment dataset
|
37 |
+
stored in HDF5 files.
|
38 |
+
"""
|
39 |
+
def __init__(self):
|
40 |
+
# The name of your dataset
|
41 |
+
self.DATASET_NAME = "agilex"
|
42 |
+
|
43 |
+
self.data_dir = "data/datasets/rdt-ft-data/demo_1k"
|
44 |
+
self.tasks = os.listdir(self.data_dir)
|
45 |
+
# Multiple tasks
|
46 |
+
self.tasks = ['PickCube-v1', 'StackCube-v1', 'PlugCharger-v1', 'PushCube-v1', 'PegInsertionSide-v1']
|
47 |
+
# Load configuration from YAML file
|
48 |
+
with open('configs/base.yaml', 'r') as file:
|
49 |
+
config = yaml.safe_load(file)
|
50 |
+
self.CHUNK_SIZE = config['common']['action_chunk_size']
|
51 |
+
self.IMG_HISTORY_SIZE = config['common']['img_history_size']
|
52 |
+
self.STATE_DIM = config['common']['state_dim']
|
53 |
+
|
54 |
+
self.num_episode_per_task = 1000
|
55 |
+
self.img = []
|
56 |
+
self.state = []
|
57 |
+
self.action = []
|
58 |
+
|
59 |
+
# open the hdf5 files in memory to speed up the data loading
|
60 |
+
for task in self.tasks:
|
61 |
+
file_path = glob.glob(os.path.join(self.data_dir, task, 'motionplanning', '*.h5'))[0]
|
62 |
+
with h5py.File(file_path, "r") as f:
|
63 |
+
trajs = f.keys() # traj_0, traj_1,
|
64 |
+
# sort by the traj number
|
65 |
+
trajs = sorted(trajs, key=lambda x: int(x.split('_')[-1]))
|
66 |
+
for traj in trajs:
|
67 |
+
# images = f[traj]['obs']['sensor_data']['base_camera']['rgb'][:]
|
68 |
+
states = f[traj]['obs']['agent']['qpos'][:]
|
69 |
+
actions = f[traj]['actions'][:]
|
70 |
+
|
71 |
+
self.state.append(states)
|
72 |
+
self.action.append(actions)
|
73 |
+
# self.img.append(images)
|
74 |
+
|
75 |
+
self.state_min = np.concatenate(self.state).min(axis=0)
|
76 |
+
self.state_max = np.concatenate(self.state).max(axis=0)
|
77 |
+
self.action_min = np.concatenate(self.action).min(axis=0)
|
78 |
+
self.action_max = np.concatenate(self.action).max(axis=0)
|
79 |
+
self.action_std = np.concatenate(self.action).std(axis=0)
|
80 |
+
self.action_mean = np.concatenate(self.action).mean(axis=0)
|
81 |
+
|
82 |
+
self.task2lang = {
|
83 |
+
"PegInsertionSide-v1": "Pick up a orange-white peg and insert the orange end into the box with a hole in it.",
|
84 |
+
"PickCube-v1": "Grasp a red cube and move it to a target goal position.",
|
85 |
+
"StackCube-v1": "Pick up a red cube and stack it on top of a green cube and let go of the cube without it falling.",
|
86 |
+
"PlugCharger-v1": "Pick up one of the misplaced shapes on the board/kit and insert it into the correct empty slot.",
|
87 |
+
"PushCube-v1": "Push and move a cube to a goal region in front of it."
|
88 |
+
}
|
89 |
+
|
90 |
+
def __len__(self):
|
91 |
+
# Assume each file contains 100 episodes
|
92 |
+
return len(self.tasks) * self.num_episode_per_task
|
93 |
+
|
94 |
+
def get_dataset_name(self):
|
95 |
+
return self.DATASET_NAME
|
96 |
+
|
97 |
+
def get_item(self, index=None):
|
98 |
+
"""
|
99 |
+
Get a training sample at a random timestep.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
index (int, optional): The index of the episode.
|
103 |
+
If not provided, a random episode will be selected.
|
104 |
+
state_only (bool, optional): Whether to return only the state.
|
105 |
+
In this way, the sample will contain a complete trajectory rather
|
106 |
+
than a single timestep. Defaults to False.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
sample (dict): A dictionary containing the training sample.
|
110 |
+
"""
|
111 |
+
while True:
|
112 |
+
if index is None:
|
113 |
+
index = np.random.randint(0, self.__len__())
|
114 |
+
valid, sample = self.parse_hdf5_file(index)
|
115 |
+
if valid:
|
116 |
+
return sample
|
117 |
+
else:
|
118 |
+
index = np.random.randint(0, self.__len__())
|
119 |
+
|
120 |
+
def parse_hdf5_file(self, index):
|
121 |
+
"""
|
122 |
+
Parse an HDF5 file to generate a training sample at a random timestep.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
file_path (str): The path to the HDF5 file.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
valid (bool): Whether the episode is valid.
|
129 |
+
dict: A dictionary containing the training sample.
|
130 |
+
"""
|
131 |
+
num_steps = len(self.action[index])
|
132 |
+
step_index = np.random.randint(0, num_steps)
|
133 |
+
task_index = index // self.num_episode_per_task
|
134 |
+
language = self.task2lang[self.tasks[task_index]]
|
135 |
+
task_inner_index = index % self.num_episode_per_task
|
136 |
+
# Skip these episodes since in the eef version dataset they are invalid.
|
137 |
+
if self.tasks[task_index] == 'PegInsertionSide-v1' and task_inner_index > 400:
|
138 |
+
return False, None
|
139 |
+
proc_index = task_inner_index // 100
|
140 |
+
episode_index = task_inner_index % 100
|
141 |
+
# images0 = self.img[index]
|
142 |
+
# normalize to -1, 1
|
143 |
+
states = (self.state[index] - self.state_min) / (self.state_max - self.state_min) * 2 - 1
|
144 |
+
states = states[:, :-1] # remove the last state as it is replicate of the -2 state
|
145 |
+
actions = (self.action[index] - self.action_min) / (self.action_max - self.action_min) * 2 - 1
|
146 |
+
|
147 |
+
# Get image history
|
148 |
+
start_img_idx = max(0, step_index - self.IMG_HISTORY_SIZE + 1)
|
149 |
+
end_img_idx = step_index + 1
|
150 |
+
img_history = []
|
151 |
+
for i in range(start_img_idx, end_img_idx):
|
152 |
+
img_path = os.path.join(self.data_dir, self.tasks[task_index], 'motionplanning', f'{proc_index}', f'{episode_index}', f"{i + 1}.png")
|
153 |
+
img = np.array(Image.open(img_path))
|
154 |
+
img_history.append(img)
|
155 |
+
img_history = np.array(img_history)
|
156 |
+
# img_history = images0[start_img_idx:end_img_idx]
|
157 |
+
img_valid_len = img_history.shape[0]
|
158 |
+
|
159 |
+
# Pad images if necessary
|
160 |
+
if img_valid_len < self.IMG_HISTORY_SIZE:
|
161 |
+
padding = np.tile(img_history[0:1], (self.IMG_HISTORY_SIZE - img_valid_len, 1, 1, 1))
|
162 |
+
img_history = np.concatenate([padding, img_history], axis=0)
|
163 |
+
|
164 |
+
img_history_mask = np.array(
|
165 |
+
[False] * (self.IMG_HISTORY_SIZE - img_valid_len) + [True] * img_valid_len
|
166 |
+
)
|
167 |
+
|
168 |
+
# Compute state statistics
|
169 |
+
state_std = np.std(states, axis=0)
|
170 |
+
state_mean = np.mean(states, axis=0)
|
171 |
+
state_norm = np.sqrt(np.mean(states ** 2, axis=0))
|
172 |
+
|
173 |
+
# Get state and action at the specified timestep
|
174 |
+
state = states[step_index: step_index + 1]
|
175 |
+
runtime_chunksize = self.CHUNK_SIZE // 4
|
176 |
+
action_sequence = actions[step_index: step_index + runtime_chunksize]
|
177 |
+
# we use linear interpolation to pad the action sequence
|
178 |
+
|
179 |
+
# Pad action sequence if necessary
|
180 |
+
if action_sequence.shape[0] < runtime_chunksize:
|
181 |
+
padding = np.tile(action_sequence[-1:], (runtime_chunksize - action_sequence.shape[0], 1))
|
182 |
+
action_sequence = np.concatenate([action_sequence, padding], axis=0)
|
183 |
+
|
184 |
+
action_sequence = interpolate_action_sequence(action_sequence, self.CHUNK_SIZE)
|
185 |
+
|
186 |
+
# Fill state and action into unified vectors
|
187 |
+
def fill_in_state(values):
|
188 |
+
UNI_STATE_INDICES = [
|
189 |
+
STATE_VEC_IDX_MAPPING[f"right_arm_joint_{i}_pos"] for i in range(7)
|
190 |
+
] + [
|
191 |
+
STATE_VEC_IDX_MAPPING[f"right_gripper_open"]
|
192 |
+
]
|
193 |
+
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
|
194 |
+
uni_vec[..., UNI_STATE_INDICES] = values
|
195 |
+
return uni_vec
|
196 |
+
|
197 |
+
state_indicator = fill_in_state(np.ones_like(state_std))
|
198 |
+
state = fill_in_state(state)
|
199 |
+
state_std = fill_in_state(state_std)
|
200 |
+
state_mean = fill_in_state(state_mean)
|
201 |
+
state_norm = fill_in_state(state_norm)
|
202 |
+
action_sequence = fill_in_state(action_sequence)
|
203 |
+
|
204 |
+
# Assemble the meta information
|
205 |
+
meta = {
|
206 |
+
"dataset_name": self.DATASET_NAME,
|
207 |
+
"#steps": num_steps,
|
208 |
+
"step_id": step_index,
|
209 |
+
"instruction": language
|
210 |
+
}
|
211 |
+
|
212 |
+
# Return the resulting sample
|
213 |
+
return True, {
|
214 |
+
"meta": meta,
|
215 |
+
"state": state,
|
216 |
+
"state_std": state_std,
|
217 |
+
"state_mean": state_mean,
|
218 |
+
"state_norm": state_norm,
|
219 |
+
"actions": action_sequence,
|
220 |
+
"state_indicator": state_indicator,
|
221 |
+
"cam_high": img_history, # Assuming images0 are high-level camera images
|
222 |
+
"cam_high_mask": img_history_mask,
|
223 |
+
"cam_left_wrist": np.zeros((self.IMG_HISTORY_SIZE, 0, 0, 0)),
|
224 |
+
"cam_left_wrist_mask": np.zeros(self.IMG_HISTORY_SIZE, dtype=bool),
|
225 |
+
"cam_right_wrist": np.zeros((self.IMG_HISTORY_SIZE, 0, 0, 0)),
|
226 |
+
"cam_right_wrist_mask": np.zeros(self.IMG_HISTORY_SIZE, dtype=bool),
|
227 |
+
}
|
228 |
+
|
229 |
+
|
230 |
+
if __name__ == "__main__":
|
231 |
+
from PIL import Image
|
232 |
+
|
233 |
+
ds = HDF5VLADataset()
|
234 |
+
|
235 |
+
json_data = {
|
236 |
+
'state_min': ds.state_min.tolist(),
|
237 |
+
'state_max': ds.state_max.tolist(),
|
238 |
+
'action_min': ds.action_min.tolist(),
|
239 |
+
'action_max': ds.action_max.tolist(),
|
240 |
+
}
|
241 |
+
print(json_data)
|
242 |
+
|
243 |
+
|
data/hdf5_vla_dataset.py
ADDED
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import fnmatch
|
3 |
+
import json
|
4 |
+
|
5 |
+
import h5py
|
6 |
+
import yaml
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING
|
11 |
+
TABLETOP_6D_INDICES_NAMES = [
|
12 |
+
'left_eef_pos_x','left_eef_pos_y','left_eef_pos_z','left_eef_angle_0','left_eef_angle_1','left_eef_angle_2','left_eef_angle_3','left_eef_angle_4','left_eef_angle_5','left_gripper_open','right_eef_pos_x','right_eef_pos_y','right_eef_pos_z','right_eef_angle_0','right_eef_angle_1','right_eef_angle_2','right_eef_angle_3','right_eef_angle_4','right_eef_angle_5','right_gripper_open']
|
13 |
+
TABLETOP_6D_INDICES = [STATE_VEC_IDX_MAPPING[n] for n in TABLETOP_6D_INDICES_NAMES]
|
14 |
+
|
15 |
+
class TabletopHDF5VLADataset:
|
16 |
+
"""
|
17 |
+
This class is used to sample episodes from the embododiment dataset
|
18 |
+
stored in HDF5.
|
19 |
+
"""
|
20 |
+
def __init__(self, task_name) -> None:
|
21 |
+
# [Modify] The path to the HDF5 dataset directory
|
22 |
+
# Each HDF5 file contains one episode
|
23 |
+
dataset_name = task_name
|
24 |
+
HDF5_DIR = f"/data5/jellyho/tabletop/{dataset_name}/"
|
25 |
+
self.DATASET_NAME = dataset_name
|
26 |
+
|
27 |
+
self.file_paths = []
|
28 |
+
for root, _, files in os.walk(HDF5_DIR):
|
29 |
+
for filename in fnmatch.filter(files, '*.hdf5'):
|
30 |
+
file_path = os.path.join(root, filename)
|
31 |
+
self.file_paths.append(file_path)
|
32 |
+
|
33 |
+
# Load the config
|
34 |
+
with open('configs/base.yaml', 'r') as file:
|
35 |
+
config = yaml.safe_load(file)
|
36 |
+
self.CHUNK_SIZE = config['common']['action_chunk_size']
|
37 |
+
self.IMG_HISORY_SIZE = config['common']['img_history_size']
|
38 |
+
self.STATE_DIM = config['common']['state_dim']
|
39 |
+
|
40 |
+
# Get each episode's len
|
41 |
+
episode_lens = []
|
42 |
+
for file_path in self.file_paths:
|
43 |
+
valid, res = self.parse_hdf5_file_state_only(file_path)
|
44 |
+
_len = res['state'].shape[0] if valid else 0
|
45 |
+
episode_lens.append(_len)
|
46 |
+
self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.file_paths)
|
50 |
+
|
51 |
+
def get_dataset_name(self):
|
52 |
+
return self.DATASET_NAME
|
53 |
+
|
54 |
+
def get_item(self, index: int=None, state_only=False):
|
55 |
+
"""Get a training sample at a random timestep.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
index (int, optional): the index of the episode.
|
59 |
+
If not provided, a random episode will be selected.
|
60 |
+
state_only (bool, optional): Whether to return only the state.
|
61 |
+
In this way, the sample will contain a complete trajectory rather
|
62 |
+
than a single timestep. Defaults to False.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
sample (dict): a dictionary containing the training sample.
|
66 |
+
"""
|
67 |
+
while True:
|
68 |
+
if index is None:
|
69 |
+
file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
|
70 |
+
else:
|
71 |
+
file_path = self.file_paths[index]
|
72 |
+
valid, sample = self.parse_hdf5_file(file_path) \
|
73 |
+
if not state_only else self.parse_hdf5_file_state_only(file_path)
|
74 |
+
if valid:
|
75 |
+
return sample
|
76 |
+
else:
|
77 |
+
index = np.random.randint(0, len(self.file_paths))
|
78 |
+
|
79 |
+
def parse_hdf5_file(self, file_path):
|
80 |
+
"""[Modify] Parse a hdf5 file to generate a training sample at
|
81 |
+
a random timestep.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
file_path (str): the path to the hdf5 file
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
valid (bool): whether the episode is valid, which is useful for filtering.
|
88 |
+
If False, this episode will be dropped.
|
89 |
+
dict: a dictionary containing the training sample,
|
90 |
+
{
|
91 |
+
"meta": {
|
92 |
+
"dataset_name": str, # the name of your dataset.
|
93 |
+
"#steps": int, # the number of steps in the episode,
|
94 |
+
# also the total timesteps.
|
95 |
+
"instruction": str # the language instruction for this episode.
|
96 |
+
},
|
97 |
+
"step_id": int, # the index of the sampled step,
|
98 |
+
# also the timestep t.
|
99 |
+
"state": ndarray, # state[t], (1, STATE_DIM).
|
100 |
+
"state_std": ndarray, # std(state[:]), (STATE_DIM,).
|
101 |
+
"state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
|
102 |
+
"state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
|
103 |
+
"actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
|
104 |
+
"state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
|
105 |
+
"cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
|
106 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
107 |
+
"cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
|
108 |
+
# For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
|
109 |
+
"cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
110 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
111 |
+
"cam_left_wrist_mask": ndarray,
|
112 |
+
"cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
113 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
114 |
+
# If only one wrist, make it right wrist, plz.
|
115 |
+
"cam_right_wrist_mask": ndarray
|
116 |
+
} or None if the episode is invalid.
|
117 |
+
"""
|
118 |
+
with h5py.File(file_path, 'r') as f:
|
119 |
+
states = f['observations']['states']['ee_6d_pos'][:]
|
120 |
+
actions = f['actions']['ee_6d_pos'][:]
|
121 |
+
num_steps = states.shape[0]
|
122 |
+
# [Optional] We drop too-short episode
|
123 |
+
if num_steps < 20:
|
124 |
+
return False, None
|
125 |
+
|
126 |
+
# We randomly sample a timestep
|
127 |
+
step_id = np.random.randint(0, num_steps)
|
128 |
+
|
129 |
+
# You can also use precomputed language embeddings (recommended)
|
130 |
+
if self.DATASET_NAME == 'aloha_box_into_pot_easy':
|
131 |
+
instruction = f['observations']['states']['language_instruction'][0].decode('utf-8')
|
132 |
+
else:
|
133 |
+
instruction = f"lang_embed/{self.DATASET_NAME}.pt"
|
134 |
+
|
135 |
+
# Assemble the meta
|
136 |
+
meta = {
|
137 |
+
"dataset_name": self.DATASET_NAME,
|
138 |
+
"#steps": num_steps,
|
139 |
+
"step_id": step_id,
|
140 |
+
"instruction": instruction
|
141 |
+
}
|
142 |
+
|
143 |
+
# Rescale gripper to [0, 1]
|
144 |
+
states = states / np.array(
|
145 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
146 |
+
)
|
147 |
+
actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
|
148 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
149 |
+
)
|
150 |
+
|
151 |
+
# Parse the state and action
|
152 |
+
state = states[step_id:step_id+1]
|
153 |
+
state_std = np.std(states, axis=0)
|
154 |
+
state_mean = np.mean(states, axis=0)
|
155 |
+
state_norm = np.sqrt(np.mean(states**2, axis=0))
|
156 |
+
|
157 |
+
if actions.shape[0] < self.CHUNK_SIZE:
|
158 |
+
# Pad the actions using the last action
|
159 |
+
actions = np.concatenate([
|
160 |
+
actions,
|
161 |
+
np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
|
162 |
+
], axis=0)
|
163 |
+
|
164 |
+
# Fill the state/action into the unified vector
|
165 |
+
def fill_in_state(values):
|
166 |
+
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
|
167 |
+
uni_vec[..., TABLETOP_6D_INDICES] = values
|
168 |
+
return uni_vec
|
169 |
+
state = fill_in_state(state)
|
170 |
+
state_indicator = fill_in_state(np.ones_like(state_std))
|
171 |
+
state_std = fill_in_state(state_std)
|
172 |
+
state_mean = fill_in_state(state_mean)
|
173 |
+
state_norm = fill_in_state(state_norm)
|
174 |
+
# If action's format is different from state's,
|
175 |
+
# you may implement fill_in_action()
|
176 |
+
actions = fill_in_state(actions)
|
177 |
+
|
178 |
+
# Parse the images
|
179 |
+
def parse_img(key):
|
180 |
+
imgs = []
|
181 |
+
for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1):
|
182 |
+
img = f['observations']['images'][key][i]
|
183 |
+
# imgs.append(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR))
|
184 |
+
imgs.append(img)
|
185 |
+
# print(imgs)
|
186 |
+
imgs = np.stack(imgs)
|
187 |
+
if imgs.shape[0] < self.IMG_HISORY_SIZE:
|
188 |
+
# Pad the images using the first image
|
189 |
+
imgs = np.concatenate([
|
190 |
+
np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
|
191 |
+
imgs
|
192 |
+
], axis=0)
|
193 |
+
return imgs
|
194 |
+
# `cam_high` is the external camera image
|
195 |
+
cam_high = parse_img('back')
|
196 |
+
# For step_id = first_idx - 1, the valid_len should be one
|
197 |
+
valid_len = min(step_id + 1, self.IMG_HISORY_SIZE)
|
198 |
+
cam_high_mask = np.array(
|
199 |
+
[False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len
|
200 |
+
)
|
201 |
+
cam_left_wrist = parse_img('wrist_left')
|
202 |
+
cam_left_wrist_mask = cam_high_mask.copy()
|
203 |
+
cam_right_wrist = parse_img('wrist_right')
|
204 |
+
cam_right_wrist_mask = cam_high_mask.copy()
|
205 |
+
|
206 |
+
# print(cam_left_wrist is not None, cam_right_wrist is not None, cam_high is not None)
|
207 |
+
|
208 |
+
# Return the resulting sample
|
209 |
+
# For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
|
210 |
+
# E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
|
211 |
+
# if the left-wrist camera is unavailable on your robot
|
212 |
+
return True, {
|
213 |
+
"meta": meta,
|
214 |
+
"state": state,
|
215 |
+
"state_std": state_std,
|
216 |
+
"state_mean": state_mean,
|
217 |
+
"state_norm": state_norm,
|
218 |
+
"actions": actions,
|
219 |
+
"state_indicator": state_indicator,
|
220 |
+
"cam_high": cam_high,
|
221 |
+
"cam_high_mask": cam_high_mask,
|
222 |
+
"cam_left_wrist": cam_left_wrist,
|
223 |
+
"cam_left_wrist_mask": cam_left_wrist_mask,
|
224 |
+
"cam_right_wrist": cam_right_wrist,
|
225 |
+
"cam_right_wrist_mask": cam_right_wrist_mask
|
226 |
+
}
|
227 |
+
|
228 |
+
def parse_hdf5_file_state_only(self, file_path):
|
229 |
+
"""[Modify] Parse a hdf5 file to generate a state trajectory.
|
230 |
+
|
231 |
+
Args:
|
232 |
+
file_path (str): the path to the hdf5 file
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
valid (bool): whether the episode is valid, which is useful for filtering.
|
236 |
+
If False, this episode will be dropped.
|
237 |
+
dict: a dictionary containing the training sample,
|
238 |
+
{
|
239 |
+
"state": ndarray, # state[:], (T, STATE_DIM).
|
240 |
+
"action": ndarray, # action[:], (T, STATE_DIM).
|
241 |
+
} or None if the episode is invalid.
|
242 |
+
"""
|
243 |
+
with h5py.File(file_path, 'r') as f:
|
244 |
+
states = f['observations']['states']['ee_6d_pos'][:]
|
245 |
+
actions = f['actions']['ee_6d_pos'][:]
|
246 |
+
num_steps = states.shape[0]
|
247 |
+
|
248 |
+
step_id = np.random.randint(0, num_steps)
|
249 |
+
|
250 |
+
# Rescale gripper to [0, 1]
|
251 |
+
states = states / np.array(
|
252 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
253 |
+
)
|
254 |
+
actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
|
255 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
256 |
+
)
|
257 |
+
|
258 |
+
# Fill the state/action into the unified vector
|
259 |
+
def fill_in_state(values):
|
260 |
+
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
|
261 |
+
uni_vec[..., TABLETOP_6D_INDICES] = values
|
262 |
+
return uni_vec
|
263 |
+
state = fill_in_state(states)
|
264 |
+
action = fill_in_state(actions)
|
265 |
+
|
266 |
+
# Return the resulting sample
|
267 |
+
return True, {
|
268 |
+
"state": state,
|
269 |
+
"action": action
|
270 |
+
}
|
271 |
+
|
272 |
+
class AnubisHDF5VLADataset:
|
273 |
+
"""
|
274 |
+
This class is used to sample episodes from the embododiment dataset
|
275 |
+
stored in HDF5.
|
276 |
+
"""
|
277 |
+
def __init__(self, task_name) -> None:
|
278 |
+
# [Modify] The path to the HDF5 dataset directory
|
279 |
+
# Each HDF5 file contains one episode
|
280 |
+
dataset_name = task_name
|
281 |
+
HDF5_DIR = f"/data5/jellyho/anubis_hdf5/{dataset_name}/"
|
282 |
+
self.DATASET_NAME = dataset_name
|
283 |
+
|
284 |
+
self.file_paths = []
|
285 |
+
for root, _, files in os.walk(HDF5_DIR):
|
286 |
+
for filename in fnmatch.filter(files, '*.hdf5'):
|
287 |
+
file_path = os.path.join(root, filename)
|
288 |
+
self.file_paths.append(file_path)
|
289 |
+
|
290 |
+
# Load the config
|
291 |
+
with open('configs/base.yaml', 'r') as file:
|
292 |
+
config = yaml.safe_load(file)
|
293 |
+
self.CHUNK_SIZE = config['common']['action_chunk_size']
|
294 |
+
self.IMG_HISORY_SIZE = config['common']['img_history_size']
|
295 |
+
self.STATE_DIM = config['common']['state_dim']
|
296 |
+
|
297 |
+
# Get each episode's len
|
298 |
+
episode_lens = []
|
299 |
+
for file_path in self.file_paths:
|
300 |
+
valid, res = self.parse_hdf5_file_state_only(file_path)
|
301 |
+
_len = res['state'].shape[0] if valid else 0
|
302 |
+
episode_lens.append(_len)
|
303 |
+
self.episode_sample_weights = np.array(episode_lens) / np.sum(episode_lens)
|
304 |
+
|
305 |
+
def __len__(self):
|
306 |
+
return len(self.file_paths)
|
307 |
+
|
308 |
+
def get_dataset_name(self):
|
309 |
+
return self.DATASET_NAME
|
310 |
+
|
311 |
+
def get_item(self, index: int=None, state_only=False):
|
312 |
+
"""Get a training sample at a random timestep.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
index (int, optional): the index of the episode.
|
316 |
+
If not provided, a random episode will be selected.
|
317 |
+
state_only (bool, optional): Whether to return only the state.
|
318 |
+
In this way, the sample will contain a complete trajectory rather
|
319 |
+
than a single timestep. Defaults to False.
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
sample (dict): a dictionary containing the training sample.
|
323 |
+
"""
|
324 |
+
while True:
|
325 |
+
if index is None:
|
326 |
+
file_path = np.random.choice(self.file_paths, p=self.episode_sample_weights)
|
327 |
+
else:
|
328 |
+
file_path = self.file_paths[index]
|
329 |
+
valid, sample = self.parse_hdf5_file(file_path) \
|
330 |
+
if not state_only else self.parse_hdf5_file_state_only(file_path)
|
331 |
+
if valid:
|
332 |
+
return sample
|
333 |
+
else:
|
334 |
+
index = np.random.randint(0, len(self.file_paths))
|
335 |
+
|
336 |
+
def parse_hdf5_file(self, file_path):
|
337 |
+
"""[Modify] Parse a hdf5 file to generate a training sample at
|
338 |
+
a random timestep.
|
339 |
+
|
340 |
+
Args:
|
341 |
+
file_path (str): the path to the hdf5 file
|
342 |
+
|
343 |
+
Returns:
|
344 |
+
valid (bool): whether the episode is valid, which is useful for filtering.
|
345 |
+
If False, this episode will be dropped.
|
346 |
+
dict: a dictionary containing the training sample,
|
347 |
+
{
|
348 |
+
"meta": {
|
349 |
+
"dataset_name": str, # the name of your dataset.
|
350 |
+
"#steps": int, # the number of steps in the episode,
|
351 |
+
# also the total timesteps.
|
352 |
+
"instruction": str # the language instruction for this episode.
|
353 |
+
},
|
354 |
+
"step_id": int, # the index of the sampled step,
|
355 |
+
# also the timestep t.
|
356 |
+
"state": ndarray, # state[t], (1, STATE_DIM).
|
357 |
+
"state_std": ndarray, # std(state[:]), (STATE_DIM,).
|
358 |
+
"state_mean": ndarray, # mean(state[:]), (STATE_DIM,).
|
359 |
+
"state_norm": ndarray, # norm(state[:]), (STATE_DIM,).
|
360 |
+
"actions": ndarray, # action[t:t+CHUNK_SIZE], (CHUNK_SIZE, STATE_DIM).
|
361 |
+
"state_indicator", ndarray, # indicates the validness of each dim, (STATE_DIM,).
|
362 |
+
"cam_high": ndarray, # external camera image, (IMG_HISORY_SIZE, H, W, 3)
|
363 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
364 |
+
"cam_high_mask": ndarray, # indicates the validness of each timestep, (IMG_HISORY_SIZE,) boolean array.
|
365 |
+
# For the first IMAGE_HISTORY_SIZE-1 timesteps, the mask should be False.
|
366 |
+
"cam_left_wrist": ndarray, # left wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
367 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
368 |
+
"cam_left_wrist_mask": ndarray,
|
369 |
+
"cam_right_wrist": ndarray, # right wrist camera image, (IMG_HISORY_SIZE, H, W, 3).
|
370 |
+
# or (IMG_HISORY_SIZE, 0, 0, 0) if unavailable.
|
371 |
+
# If only one wrist, make it right wrist, plz.
|
372 |
+
"cam_right_wrist_mask": ndarray
|
373 |
+
} or None if the episode is invalid.
|
374 |
+
"""
|
375 |
+
with h5py.File(file_path, 'r') as f:
|
376 |
+
states = f['observation']['eef_pose'][:]
|
377 |
+
actions = f['action']['eef_pose'][:]
|
378 |
+
num_steps = states.shape[0]
|
379 |
+
# [Optional] We drop too-short episode
|
380 |
+
if num_steps < 20:
|
381 |
+
return False, None
|
382 |
+
|
383 |
+
# We randomly sample a timestep
|
384 |
+
step_id = np.random.randint(0, num_steps)
|
385 |
+
|
386 |
+
# You can also use precomputed language embeddings (recommended)
|
387 |
+
if self.DATASET_NAME == 'aloha_box_into_pot_easy':
|
388 |
+
instruction = f['observations']['states']['language_instruction'][0].decode('utf-8')
|
389 |
+
else:
|
390 |
+
instruction = f"lang_embed/{self.DATASET_NAME}.pt"
|
391 |
+
|
392 |
+
# Assemble the meta
|
393 |
+
meta = {
|
394 |
+
"dataset_name": self.DATASET_NAME,
|
395 |
+
"#steps": num_steps,
|
396 |
+
"step_id": step_id,
|
397 |
+
"instruction": instruction
|
398 |
+
}
|
399 |
+
|
400 |
+
# Rescale gripper to [0, 1]
|
401 |
+
states = states / np.array(
|
402 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
403 |
+
)
|
404 |
+
actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
|
405 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
406 |
+
)
|
407 |
+
|
408 |
+
# Parse the state and action
|
409 |
+
state = states[step_id:step_id+1]
|
410 |
+
state_std = np.std(states, axis=0)
|
411 |
+
state_mean = np.mean(states, axis=0)
|
412 |
+
state_norm = np.sqrt(np.mean(states**2, axis=0))
|
413 |
+
|
414 |
+
if actions.shape[0] < self.CHUNK_SIZE:
|
415 |
+
# Pad the actions using the last action
|
416 |
+
actions = np.concatenate([
|
417 |
+
actions,
|
418 |
+
np.tile(actions[-1:], (self.CHUNK_SIZE-actions.shape[0], 1))
|
419 |
+
], axis=0)
|
420 |
+
|
421 |
+
# Fill the state/action into the unified vector
|
422 |
+
def fill_in_state(values):
|
423 |
+
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
|
424 |
+
uni_vec[..., TABLETOP_6D_INDICES] = values
|
425 |
+
return uni_vec
|
426 |
+
state = fill_in_state(state)
|
427 |
+
state_indicator = fill_in_state(np.ones_like(state_std))
|
428 |
+
state_std = fill_in_state(state_std)
|
429 |
+
state_mean = fill_in_state(state_mean)
|
430 |
+
state_norm = fill_in_state(state_norm)
|
431 |
+
# If action's format is different from state's,
|
432 |
+
# you may implement fill_in_action()
|
433 |
+
actions = fill_in_state(actions)
|
434 |
+
|
435 |
+
# Parse the images
|
436 |
+
def parse_img(key):
|
437 |
+
imgs = []
|
438 |
+
for i in range(max(step_id-self.IMG_HISORY_SIZE+1, 0), step_id+1):
|
439 |
+
img = f['observation'][key][i]
|
440 |
+
# imgs.append(cv2.imdecode(np.frombuffer(img, np.uint8), cv2.IMREAD_COLOR))
|
441 |
+
imgs.append(img)
|
442 |
+
# print(imgs)
|
443 |
+
imgs = np.stack(imgs)
|
444 |
+
if imgs.shape[0] < self.IMG_HISORY_SIZE:
|
445 |
+
# Pad the images using the first image
|
446 |
+
imgs = np.concatenate([
|
447 |
+
np.tile(imgs[:1], (self.IMG_HISORY_SIZE-imgs.shape[0], 1, 1, 1)),
|
448 |
+
imgs
|
449 |
+
], axis=0)
|
450 |
+
return imgs
|
451 |
+
# `cam_high` is the external camera image
|
452 |
+
cam_high = parse_img('agentview_image')
|
453 |
+
# For step_id = first_idx - 1, the valid_len should be one
|
454 |
+
valid_len = min(step_id + 1, self.IMG_HISORY_SIZE)
|
455 |
+
cam_high_mask = np.array(
|
456 |
+
[False] * (self.IMG_HISORY_SIZE - valid_len) + [True] * valid_len
|
457 |
+
)
|
458 |
+
cam_left_wrist = parse_img('wrist_left_image')
|
459 |
+
cam_left_wrist_mask = cam_high_mask.copy()
|
460 |
+
cam_right_wrist = parse_img('wrist_right_image')
|
461 |
+
cam_right_wrist_mask = cam_high_mask.copy()
|
462 |
+
|
463 |
+
# print(cam_left_wrist is not None, cam_right_wrist is not None, cam_high is not None)
|
464 |
+
|
465 |
+
# Return the resulting sample
|
466 |
+
# For unavailable images, return zero-shape arrays, i.e., (IMG_HISORY_SIZE, 0, 0, 0)
|
467 |
+
# E.g., return np.zeros((self.IMG_HISORY_SIZE, 0, 0, 0)) for the key "cam_left_wrist",
|
468 |
+
# if the left-wrist camera is unavailable on your robot
|
469 |
+
return True, {
|
470 |
+
"meta": meta,
|
471 |
+
"state": state,
|
472 |
+
"state_std": state_std,
|
473 |
+
"state_mean": state_mean,
|
474 |
+
"state_norm": state_norm,
|
475 |
+
"actions": actions,
|
476 |
+
"state_indicator": state_indicator,
|
477 |
+
"cam_high": cam_high,
|
478 |
+
"cam_high_mask": cam_high_mask,
|
479 |
+
"cam_left_wrist": cam_left_wrist,
|
480 |
+
"cam_left_wrist_mask": cam_left_wrist_mask,
|
481 |
+
"cam_right_wrist": cam_right_wrist,
|
482 |
+
"cam_right_wrist_mask": cam_right_wrist_mask
|
483 |
+
}
|
484 |
+
|
485 |
+
def parse_hdf5_file_state_only(self, file_path):
|
486 |
+
"""[Modify] Parse a hdf5 file to generate a state trajectory.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
file_path (str): the path to the hdf5 file
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
valid (bool): whether the episode is valid, which is useful for filtering.
|
493 |
+
If False, this episode will be dropped.
|
494 |
+
dict: a dictionary containing the training sample,
|
495 |
+
{
|
496 |
+
"state": ndarray, # state[:], (T, STATE_DIM).
|
497 |
+
"action": ndarray, # action[:], (T, STATE_DIM).
|
498 |
+
} or None if the episode is invalid.
|
499 |
+
"""
|
500 |
+
with h5py.File(file_path, 'r') as f:
|
501 |
+
states = f['observation']['eef_pose'][:]
|
502 |
+
actions = f['action']['eef_pose'][:]
|
503 |
+
num_steps = states.shape[0]
|
504 |
+
|
505 |
+
step_id = np.random.randint(0, num_steps)
|
506 |
+
|
507 |
+
# Rescale gripper to [0, 1]
|
508 |
+
states = states / np.array(
|
509 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
510 |
+
)
|
511 |
+
actions = actions[step_id:step_id+self.CHUNK_SIZE] / np.array(
|
512 |
+
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
|
513 |
+
)
|
514 |
+
|
515 |
+
# Fill the state/action into the unified vector
|
516 |
+
def fill_in_state(values):
|
517 |
+
uni_vec = np.zeros(values.shape[:-1] + (self.STATE_DIM,))
|
518 |
+
uni_vec[..., TABLETOP_6D_INDICES] = values
|
519 |
+
return uni_vec
|
520 |
+
state = fill_in_state(states)
|
521 |
+
action = fill_in_state(actions)
|
522 |
+
|
523 |
+
# Return the resulting sample
|
524 |
+
return True, {
|
525 |
+
"state": state,
|
526 |
+
"action": action
|
527 |
+
}
|
528 |
+
|
529 |
+
if __name__ == "__main__":
|
530 |
+
ds = TabletopHDF5VLADataset()
|
531 |
+
for i in range(len(ds)):
|
532 |
+
print(f"Processing episode {i}/{len(ds)}...")
|
533 |
+
ds.get_item(i)
|
data/preprocess.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
|
3 |
+
import tensorflow as tf
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
from data.preprocess_scripts import *
|
7 |
+
from configs.state_vec import STATE_VEC_IDX_MAPPING, STATE_VEC_LEN
|
8 |
+
from data.utils import capitalize_and_period
|
9 |
+
|
10 |
+
# The dataset without state
|
11 |
+
DATASET_NAMES_NO_STATE = [
|
12 |
+
'nyu_door_opening_surprising_effectiveness',
|
13 |
+
"usc_cloth_sim_converted_externally_to_rlds",
|
14 |
+
'cmu_franka_exploration_dataset_converted_externally_to_rlds',
|
15 |
+
'imperialcollege_sawyer_wrist_cam'
|
16 |
+
]
|
17 |
+
|
18 |
+
# Read the image keys of each dataset
|
19 |
+
with open('configs/dataset_img_keys.json', 'r') as file:
|
20 |
+
IMAGE_KEYS = json.load(file)
|
21 |
+
# Read the config
|
22 |
+
with open('configs/base.yaml', 'r') as file:
|
23 |
+
config = yaml.safe_load(file)
|
24 |
+
|
25 |
+
|
26 |
+
def assemble_state_vec(arm_concat: tf.Tensor, arm_format: str,
|
27 |
+
base_concat=None, base_format=None) -> tf.Tensor:
|
28 |
+
"""
|
29 |
+
Assemble the state/action vector from the arm and base.
|
30 |
+
"""
|
31 |
+
state_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
|
32 |
+
mask_vec = tf.zeros(STATE_VEC_LEN, dtype=tf.float32)
|
33 |
+
|
34 |
+
# Assemble the arm state
|
35 |
+
arm_concat = tf.cast(arm_concat, tf.float32)
|
36 |
+
arm_format = arm_format.split(',')
|
37 |
+
# Use the scatter_nd to avoid the duplicate indices
|
38 |
+
state_vec = tf.tensor_scatter_nd_update(
|
39 |
+
state_vec,
|
40 |
+
[[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
|
41 |
+
arm_concat
|
42 |
+
)
|
43 |
+
mask_vec = tf.tensor_scatter_nd_update(
|
44 |
+
mask_vec,
|
45 |
+
[[STATE_VEC_IDX_MAPPING[name]] for name in arm_format],
|
46 |
+
tf.ones(len(arm_format), dtype=tf.float32)
|
47 |
+
)
|
48 |
+
|
49 |
+
# Assemble the base state if exists
|
50 |
+
if base_concat is not None:
|
51 |
+
base_concat = tf.cast(base_concat, tf.float32)
|
52 |
+
base_format = base_format.split(',')
|
53 |
+
state_vec = tf.tensor_scatter_nd_update(
|
54 |
+
state_vec,
|
55 |
+
[[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
|
56 |
+
base_concat
|
57 |
+
)
|
58 |
+
mask_vec = tf.tensor_scatter_nd_update(
|
59 |
+
mask_vec,
|
60 |
+
[[STATE_VEC_IDX_MAPPING[name]] for name in base_format],
|
61 |
+
tf.ones(len(base_format), dtype=tf.float32)
|
62 |
+
)
|
63 |
+
return state_vec, mask_vec
|
64 |
+
|
65 |
+
|
66 |
+
@tf.autograph.experimental.do_not_convert
|
67 |
+
def _generate_json_state_agilex(episode: dict, dataset_name: str):
|
68 |
+
"""
|
69 |
+
Generate the json dict and state for a given episode.
|
70 |
+
"""
|
71 |
+
# Load some constants from the config
|
72 |
+
IMG_HISTORY_SIZE = config['common']['img_history_size']
|
73 |
+
if IMG_HISTORY_SIZE < 1:
|
74 |
+
raise ValueError("Config `img_history_size` must be at least 1.")
|
75 |
+
ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
|
76 |
+
if ACTION_CHUNK_SIZE < 1:
|
77 |
+
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
78 |
+
|
79 |
+
# Initialize the episode_metadata
|
80 |
+
episode_metadata = {
|
81 |
+
'dataset_name': dataset_name,
|
82 |
+
'#steps': 0,
|
83 |
+
'instruction': None
|
84 |
+
}
|
85 |
+
|
86 |
+
# Check whether this episode has an 'END'
|
87 |
+
base_act = None
|
88 |
+
last_base_act = None
|
89 |
+
episode_states = []
|
90 |
+
episode_acts = []
|
91 |
+
episode_masks = []
|
92 |
+
has_base = None
|
93 |
+
for step_id, step in enumerate(iter(episode['steps'])):
|
94 |
+
# Parse the action
|
95 |
+
action = step['action']
|
96 |
+
if has_base is None:
|
97 |
+
has_base = 'base_concat' in action
|
98 |
+
if has_base:
|
99 |
+
base_act = action['base_concat']
|
100 |
+
|
101 |
+
# Parse the state
|
102 |
+
state = step['observation']
|
103 |
+
|
104 |
+
arm_format = state['format'].numpy().decode('utf-8')
|
105 |
+
base_format = None
|
106 |
+
if has_base:
|
107 |
+
act_format = action['format'].numpy().decode('utf-8')
|
108 |
+
base_formate_idx = act_format.find('base')
|
109 |
+
base_format = act_format[base_formate_idx:]
|
110 |
+
|
111 |
+
arm_state = state['arm_concat']
|
112 |
+
base_state = None
|
113 |
+
if has_base:
|
114 |
+
if last_base_act is None:
|
115 |
+
base_state = base_act * 0
|
116 |
+
else:
|
117 |
+
base_state = last_base_act
|
118 |
+
last_base_act = base_act
|
119 |
+
|
120 |
+
# Assemble the state vector
|
121 |
+
state_vec, mask_vec = assemble_state_vec(
|
122 |
+
arm_state, arm_format, base_state, base_format)
|
123 |
+
|
124 |
+
|
125 |
+
act_vec, mask_vec = assemble_state_vec(
|
126 |
+
action['arm_concat'], arm_format, base_state, base_format
|
127 |
+
)
|
128 |
+
|
129 |
+
episode_states.append(state_vec)
|
130 |
+
episode_masks.append(mask_vec)
|
131 |
+
episode_acts.append(act_vec)
|
132 |
+
|
133 |
+
# Parse the task instruction
|
134 |
+
instr = step['observation']['natural_language_instruction']
|
135 |
+
instr = instr.numpy().decode('utf-8')
|
136 |
+
instr = capitalize_and_period(instr)
|
137 |
+
|
138 |
+
# Write to the episode_metadata
|
139 |
+
if episode_metadata['instruction'] is None:
|
140 |
+
episode_metadata['instruction'] = instr
|
141 |
+
|
142 |
+
episode_metadata['#steps'] = step_id
|
143 |
+
|
144 |
+
episode_states = tf.stack(episode_states)
|
145 |
+
episode_masks = tf.stack(episode_masks)
|
146 |
+
episode_acts = tf.stack(episode_acts)
|
147 |
+
|
148 |
+
return episode_metadata, episode_states, episode_masks, episode_acts
|
149 |
+
|
150 |
+
|
151 |
+
@tf.autograph.experimental.do_not_convert
|
152 |
+
def _generate_json_state(episode: dict, dataset_name: str):
|
153 |
+
"""
|
154 |
+
Generate the json dict and state for a given episode.
|
155 |
+
"""
|
156 |
+
# Load some constants from the config
|
157 |
+
IMG_HISTORY_SIZE = config['common']['img_history_size']
|
158 |
+
if IMG_HISTORY_SIZE < 1:
|
159 |
+
raise ValueError("Config `img_history_size` must be at least 1.")
|
160 |
+
ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
|
161 |
+
if ACTION_CHUNK_SIZE < 1:
|
162 |
+
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
163 |
+
|
164 |
+
# Initialize the episode_metadata
|
165 |
+
episode_metadata = {
|
166 |
+
'dataset_name': dataset_name,
|
167 |
+
'#steps': 0,
|
168 |
+
'instruction': None
|
169 |
+
}
|
170 |
+
|
171 |
+
# Check whether this episode has an 'END'
|
172 |
+
base_act = None
|
173 |
+
last_base_act = None
|
174 |
+
episode_states = []
|
175 |
+
episode_masks = []
|
176 |
+
has_base = None
|
177 |
+
for step_id, step in enumerate(iter(episode['steps'])):
|
178 |
+
# Parse the action
|
179 |
+
action = step['action']
|
180 |
+
if has_base is None:
|
181 |
+
has_base = 'base_concat' in action
|
182 |
+
if has_base:
|
183 |
+
base_act = action['base_concat']
|
184 |
+
|
185 |
+
# Parse the state
|
186 |
+
state = step['observation']
|
187 |
+
|
188 |
+
arm_format = state['format'].numpy().decode('utf-8')
|
189 |
+
base_format = None
|
190 |
+
if has_base:
|
191 |
+
act_format = action['format'].numpy().decode('utf-8')
|
192 |
+
base_formate_idx = act_format.find('base')
|
193 |
+
base_format = act_format[base_formate_idx:]
|
194 |
+
|
195 |
+
arm_state = state['arm_concat']
|
196 |
+
base_state = None
|
197 |
+
if has_base:
|
198 |
+
if last_base_act is None:
|
199 |
+
base_state = base_act * 0
|
200 |
+
else:
|
201 |
+
base_state = last_base_act
|
202 |
+
last_base_act = base_act
|
203 |
+
|
204 |
+
# Assemble the state vector
|
205 |
+
state_vec, mask_vec = assemble_state_vec(
|
206 |
+
arm_state, arm_format, base_state, base_format)
|
207 |
+
|
208 |
+
episode_states.append(state_vec)
|
209 |
+
episode_masks.append(mask_vec)
|
210 |
+
|
211 |
+
# Parse the task instruction
|
212 |
+
instr = step['observation']['natural_language_instruction']
|
213 |
+
instr = instr.numpy().decode('utf-8')
|
214 |
+
instr = capitalize_and_period(instr)
|
215 |
+
|
216 |
+
# Write to the episode_metadata
|
217 |
+
if episode_metadata['instruction'] is None:
|
218 |
+
episode_metadata['instruction'] = instr
|
219 |
+
|
220 |
+
episode_metadata['#steps'] = step_id
|
221 |
+
episode_states = tf.stack(episode_states)
|
222 |
+
episode_masks = tf.stack(episode_masks)
|
223 |
+
|
224 |
+
return episode_metadata, episode_states, episode_masks
|
225 |
+
|
226 |
+
|
227 |
+
@tf.autograph.experimental.do_not_convert
|
228 |
+
def _generate_json_state_nostate_ds(episode: dict, dataset_name: str):
|
229 |
+
"""
|
230 |
+
Generate the json dict and state for an episode in the dataset without state.
|
231 |
+
If not state, we use the last action as current state.
|
232 |
+
"""
|
233 |
+
# Load some constants from the config
|
234 |
+
IMG_HISTORY_SIZE = config['common']['img_history_size']
|
235 |
+
if IMG_HISTORY_SIZE < 1:
|
236 |
+
raise ValueError("Config `img_history_size` must be at least 1.")
|
237 |
+
ACTION_CHUNK_SIZE = config['common']['action_chunk_size']
|
238 |
+
if ACTION_CHUNK_SIZE < 1:
|
239 |
+
raise ValueError("Config `action_chunk_size` must be at least 1.")
|
240 |
+
|
241 |
+
# Initialize the episode_metadata
|
242 |
+
episode_metadata = {
|
243 |
+
'dataset_name': dataset_name,
|
244 |
+
'#steps': 0,
|
245 |
+
'instruction': None
|
246 |
+
}
|
247 |
+
|
248 |
+
last_base_act = None
|
249 |
+
last_arm_act = None
|
250 |
+
episode_states = []
|
251 |
+
episode_masks = []
|
252 |
+
has_base = None
|
253 |
+
for step_id, step in enumerate(iter(episode['steps'])):
|
254 |
+
# Parse the action
|
255 |
+
action = step['action']
|
256 |
+
if has_base is None:
|
257 |
+
has_base = 'base_concat' in action
|
258 |
+
if has_base:
|
259 |
+
base_act = action['base_concat']
|
260 |
+
if last_base_act is None:
|
261 |
+
last_base_act = base_act * 0 # Initialize
|
262 |
+
|
263 |
+
# Parse the arm action
|
264 |
+
arm_act = action['arm_concat']
|
265 |
+
if last_arm_act is None:
|
266 |
+
last_arm_act = arm_act * 0 # Initialize
|
267 |
+
|
268 |
+
# Parse the act format
|
269 |
+
# Action format as the state format
|
270 |
+
act_format = action['format'].numpy().decode('utf-8')
|
271 |
+
|
272 |
+
# Assemble the state vector
|
273 |
+
if has_base:
|
274 |
+
last_act_concat = tf.concat([last_arm_act, last_base_act], axis=0)
|
275 |
+
else:
|
276 |
+
last_act_concat = last_arm_act
|
277 |
+
state_vec, mask_vec = assemble_state_vec(
|
278 |
+
last_act_concat, act_format)
|
279 |
+
|
280 |
+
episode_states.append(state_vec)
|
281 |
+
episode_masks.append(mask_vec)
|
282 |
+
|
283 |
+
# Parse the task instruction
|
284 |
+
instr = step['observation']['natural_language_instruction']
|
285 |
+
instr = instr.numpy().decode('utf-8')
|
286 |
+
instr = capitalize_and_period(instr)
|
287 |
+
|
288 |
+
# Write to the episode_metadata
|
289 |
+
if episode_metadata['instruction'] is None:
|
290 |
+
episode_metadata['instruction'] = instr
|
291 |
+
|
292 |
+
# Update the last_arm_act and last_base_act
|
293 |
+
last_arm_act = arm_act
|
294 |
+
if has_base:
|
295 |
+
last_base_act = base_act
|
296 |
+
|
297 |
+
episode_metadata['#steps'] = step_id
|
298 |
+
episode_states = tf.stack(episode_states)
|
299 |
+
episode_masks = tf.stack(episode_masks)
|
300 |
+
|
301 |
+
return episode_metadata, episode_states, episode_masks
|
302 |
+
|
303 |
+
|
304 |
+
@tf.autograph.experimental.do_not_convert
|
305 |
+
def generate_json_state(episode: dict, dataset_name: str):
|
306 |
+
"""
|
307 |
+
Generate the json dict and state for an episode.
|
308 |
+
"""
|
309 |
+
if isinstance(dataset_name, tf.Tensor):
|
310 |
+
dataset_name = dataset_name.numpy().decode('utf-8')
|
311 |
+
|
312 |
+
# Process each step in the episode
|
313 |
+
episode['steps'] = episode['steps'].map(
|
314 |
+
globals()[dataset_name].process_step,
|
315 |
+
)
|
316 |
+
|
317 |
+
if dataset_name == "agilex":
|
318 |
+
return _generate_json_state_agilex(episode, dataset_name)
|
319 |
+
|
320 |
+
if dataset_name in DATASET_NAMES_NO_STATE:
|
321 |
+
return _generate_json_state_nostate_ds(episode, dataset_name)
|
322 |
+
|
323 |
+
return _generate_json_state(episode, dataset_name)
|
data/preprocess_scripts/__init__.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from . import fractal20220817_data
|
2 |
+
from . import bridge
|
3 |
+
from . import jaco_play
|
4 |
+
from . import nyu_door_opening_surprising_effectiveness
|
5 |
+
from . import taco_play
|
6 |
+
from . import berkeley_cable_routing
|
7 |
+
from . import roboturk
|
8 |
+
from . import viola
|
9 |
+
from . import berkeley_autolab_ur5
|
10 |
+
from . import toto
|
11 |
+
from . import columbia_cairlab_pusht_real
|
12 |
+
from . import stanford_kuka_multimodal_dataset_converted_externally_to_rlds
|
13 |
+
from . import nyu_rot_dataset_converted_externally_to_rlds
|
14 |
+
from . import austin_buds_dataset_converted_externally_to_rlds
|
15 |
+
from . import nyu_franka_play_dataset_converted_externally_to_rlds
|
16 |
+
from . import cmu_franka_exploration_dataset_converted_externally_to_rlds
|
17 |
+
from . import kuka
|
18 |
+
from . import utokyo_xarm_bimanual_converted_externally_to_rlds
|
19 |
+
from . import maniskill_dataset_converted_externally_to_rlds
|
20 |
+
from . import stanford_hydra_dataset_converted_externally_to_rlds
|
21 |
+
from . import ucsd_kitchen_dataset_converted_externally_to_rlds
|
22 |
+
from . import ucsd_pick_and_place_dataset_converted_externally_to_rlds
|
23 |
+
from . import austin_sailor_dataset_converted_externally_to_rlds
|
24 |
+
from . import austin_sirius_dataset_converted_externally_to_rlds
|
25 |
+
from . import bc_z
|
26 |
+
from . import usc_cloth_sim_converted_externally_to_rlds
|
27 |
+
from . import utokyo_pr2_opening_fridge_converted_externally_to_rlds
|
28 |
+
from . import utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds
|
29 |
+
from . import utokyo_xarm_pick_and_place_converted_externally_to_rlds
|
30 |
+
from . import berkeley_mvp_converted_externally_to_rlds
|
31 |
+
from . import berkeley_rpt_converted_externally_to_rlds
|
32 |
+
from . import kaist_nonprehensile_converted_externally_to_rlds
|
33 |
+
from . import stanford_mask_vit_converted_externally_to_rlds
|
34 |
+
from . import tokyo_u_lsmo_converted_externally_to_rlds
|
35 |
+
from . import dlr_sara_pour_converted_externally_to_rlds
|
36 |
+
from . import dlr_sara_grid_clamp_converted_externally_to_rlds
|
37 |
+
from . import dlr_edan_shared_control_converted_externally_to_rlds
|
38 |
+
from . import asu_table_top_converted_externally_to_rlds
|
39 |
+
from . import stanford_robocook_converted_externally_to_rlds
|
40 |
+
from . import roboturk_real_laundrylayout
|
41 |
+
from . import roboturk_real_towercreation
|
42 |
+
from . import roboturk_real_objectsearch
|
43 |
+
from . import robomimic_lift_ph
|
44 |
+
from . import robomimic_can_ph
|
45 |
+
from . import robomimic_square_ph
|
46 |
+
from . import robomimic_transport_ph
|
47 |
+
from . import robomimic_tool_hang_ph
|
48 |
+
from . import eth_agent_affordances
|
49 |
+
from . import imperialcollege_sawyer_wrist_cam
|
50 |
+
from . import iamlab_cmu_pickup_insert_converted_externally_to_rlds
|
51 |
+
from . import uiuc_d3field
|
52 |
+
from . import utaustin_mutex
|
53 |
+
from . import berkeley_fanuc_manipulation
|
54 |
+
from . import cmu_play_fusion
|
55 |
+
from . import cmu_stretch
|
56 |
+
from . import berkeley_gnm_recon
|
57 |
+
from . import berkeley_gnm_cory_hall
|
58 |
+
from . import berkeley_gnm_sac_son
|
59 |
+
from . import language_table
|
60 |
+
from . import furniture_bench_dataset_converted_externally_to_rlds
|
61 |
+
from . import robo_net
|
62 |
+
from . import bridgev2
|
63 |
+
from . import aloha_mobile
|
64 |
+
from . import aloha_static
|
65 |
+
from . import droid
|
66 |
+
from . import fmb
|
67 |
+
from . import dobbe
|
68 |
+
from . import qut_dexterous_manpulation
|
69 |
+
from . import roboset
|
70 |
+
from . import agilex
|
71 |
+
from . import rh20t
|
72 |
+
from . import calvin
|
73 |
+
from . import aloha_dish_drainer, aloha_handover_box, aloha_shoes_table, aloha_lift_box, aloha_box_into_pot
|
data/preprocess_scripts/aloha_shoes_table.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
action_dict = step['action']
|
14 |
+
# Concatenate the action
|
15 |
+
step['action'] = {}
|
16 |
+
action = step['action']
|
17 |
+
action['arm_concat'] = action_dict['ee_6d_pos']
|
18 |
+
|
19 |
+
# Write the action format
|
20 |
+
action['format'] = tf.constant(
|
21 |
+
"left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open"
|
22 |
+
)
|
23 |
+
|
24 |
+
# Convert raw state to our state
|
25 |
+
# Robot state
|
26 |
+
state_dict = step['observation']['state']
|
27 |
+
state = {}
|
28 |
+
state['arm_concat'] = state_dict
|
29 |
+
|
30 |
+
# Write the state format
|
31 |
+
state['format'] = tf.constant(
|
32 |
+
"left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open"
|
33 |
+
)
|
34 |
+
# Clean the task instruction
|
35 |
+
# Define the replacements (old, new) as a dictionary
|
36 |
+
replacements = {
|
37 |
+
'_': ' ',
|
38 |
+
'1f': ' ',
|
39 |
+
'4f': ' ',
|
40 |
+
'-': ' ',
|
41 |
+
'50': ' ',
|
42 |
+
'55': ' ',
|
43 |
+
'56': ' ',
|
44 |
+
|
45 |
+
}
|
46 |
+
instr = step['language_instruction']
|
47 |
+
# instr = clean_task_instruction(instr, replacements)
|
48 |
+
step['observation'] = state
|
49 |
+
step['observation']['natural_language_instruction'] = instr
|
50 |
+
|
51 |
+
return step
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
pass
|
data/preprocess_scripts/austin_buds_dataset_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
def process_step(step: dict) -> dict:
|
6 |
+
"""
|
7 |
+
Unify the action format and clean the task instruction.
|
8 |
+
|
9 |
+
DO NOT use python list, use tf.TensorArray instead.
|
10 |
+
"""
|
11 |
+
# Convert raw action to our action
|
12 |
+
|
13 |
+
origin_action = step['action']
|
14 |
+
step['action']={}
|
15 |
+
action=step['action']
|
16 |
+
action['terminate'] = step['is_terminal']
|
17 |
+
|
18 |
+
eef_delta_pos = origin_action[:3]
|
19 |
+
eef_ang=origin_action[3:6]
|
20 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
21 |
+
# gripper_open: -1-open, 1-closed
|
22 |
+
grip_open=tf.where(tf.equal(origin_action[6:],tf.constant(-1.0)),tf.constant(1.0),tf.constant(0.0))
|
23 |
+
|
24 |
+
# No base found
|
25 |
+
|
26 |
+
# Concatenate the action
|
27 |
+
action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
|
28 |
+
|
29 |
+
# Write the action format
|
30 |
+
action['format'] = tf.constant(
|
31 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
32 |
+
|
33 |
+
# Convert raw state to our state
|
34 |
+
state = step['observation']
|
35 |
+
# Concatenate the state
|
36 |
+
eef_mat = tf.transpose(tf.reshape(state['state'][8:], (4, 4)))
|
37 |
+
eef_pos = eef_mat[:3, 3]
|
38 |
+
rotaion_matrix = eef_mat[:3, :3]
|
39 |
+
eef_ang = rotation_matrix_to_ortho6d(rotaion_matrix)
|
40 |
+
joint_pos = state['state'][:7]
|
41 |
+
grip_open = state['state'][7:8] * 12.5 # rescale to [0, 1]
|
42 |
+
state['arm_concat'] = tf.concat([joint_pos,grip_open,eef_pos,eef_ang],axis=0)
|
43 |
+
|
44 |
+
# Write the state format
|
45 |
+
state['format'] = tf.constant(
|
46 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
47 |
+
|
48 |
+
# Clean the task instruction
|
49 |
+
# Define the replacements (old, new) as a dictionary
|
50 |
+
replacements = {
|
51 |
+
'_': ' ',
|
52 |
+
'1f': ' ',
|
53 |
+
'4f': ' ',
|
54 |
+
'-': ' ',
|
55 |
+
'50': ' ',
|
56 |
+
'55': ' ',
|
57 |
+
'56': ' ',
|
58 |
+
|
59 |
+
}
|
60 |
+
instr = step['language_instruction']
|
61 |
+
instr = clean_task_instruction(instr, replacements)
|
62 |
+
step['observation']['natural_language_instruction'] = instr
|
63 |
+
|
64 |
+
return step
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
import tensorflow_datasets as tfds
|
69 |
+
from data.utils import dataset_to_path
|
70 |
+
|
71 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
72 |
+
DATASET_NAME = 'austin_buds_dataset_converted_externally_to_rlds'
|
73 |
+
# Load the dataset
|
74 |
+
dataset = tfds.builder_from_directory(
|
75 |
+
builder_dir=dataset_to_path(
|
76 |
+
DATASET_NAME, DATASET_DIR))
|
77 |
+
dataset = dataset.as_dataset(split='all')
|
78 |
+
|
79 |
+
# Inspect the dataset
|
80 |
+
for episode in dataset:
|
81 |
+
for step in episode['steps']:
|
82 |
+
print(step)
|
data/preprocess_scripts/berkeley_autolab_ur5.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, \
|
4 |
+
quaternion_to_rotation_matrix, rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
8 |
+
"""
|
9 |
+
Convert terminate action to a boolean, where True means terminate.
|
10 |
+
"""
|
11 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
12 |
+
|
13 |
+
|
14 |
+
def process_step(step: dict) -> dict:
|
15 |
+
"""
|
16 |
+
Unify the action format and clean the task instruction.
|
17 |
+
|
18 |
+
DO NOT use python list, use tf.TensorArray instead.
|
19 |
+
"""
|
20 |
+
# Convert raw action to our action
|
21 |
+
action = step['action']
|
22 |
+
action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
|
23 |
+
eef_delta_pos = action['world_vector']
|
24 |
+
eef_ang = action['rotation_delta']
|
25 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
26 |
+
|
27 |
+
# Ignore action['gripper_open']: 1 if close gripper, -1 if open gripper, 0 if no change.
|
28 |
+
|
29 |
+
# No base found
|
30 |
+
|
31 |
+
# Concatenate the action
|
32 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
|
33 |
+
action['arm_concat'] = arm_action
|
34 |
+
|
35 |
+
# Write the action format
|
36 |
+
action['format'] = tf.constant(
|
37 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
|
38 |
+
|
39 |
+
# Convert raw state to our state
|
40 |
+
state = step['observation']
|
41 |
+
# state['robot_state']:[joint0, joint1, joint2, joint3, joint4, joint5, x,y,z, qx,qy,qz,qw, gripper_is_closed, action_blocked]
|
42 |
+
robot_state = state['robot_state']
|
43 |
+
joint_pos=robot_state[:6]
|
44 |
+
eef_pos = robot_state[6:9]
|
45 |
+
eef_quat = robot_state[9:13]
|
46 |
+
eef_ang = quaternion_to_rotation_matrix(eef_quat)
|
47 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
48 |
+
# gripper_is_closed is binary: 0 = fully open; 1 = fully closed
|
49 |
+
grip_closed = robot_state[13:14]
|
50 |
+
grip_open= 1-grip_closed
|
51 |
+
# action_blocked is binary: 0 = not blocked; 1 = blocked
|
52 |
+
# action_blocked = robot_state[14:15]
|
53 |
+
|
54 |
+
# Concatenate the state
|
55 |
+
state['arm_concat'] = tf.concat([joint_pos, grip_open,eef_pos,eef_ang], axis=0)
|
56 |
+
|
57 |
+
# Write the state format
|
58 |
+
state['format'] = tf.constant(
|
59 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
60 |
+
|
61 |
+
# Clean the task instruction
|
62 |
+
# Define the replacements (old, new) as a dictionary
|
63 |
+
replacements = {
|
64 |
+
'_': ' ',
|
65 |
+
'1f': ' ',
|
66 |
+
'4f': ' ',
|
67 |
+
'-': ' ',
|
68 |
+
'50': ' ',
|
69 |
+
'55': ' ',
|
70 |
+
'56': ' ',
|
71 |
+
|
72 |
+
}
|
73 |
+
instr = step['observation']['natural_language_instruction']
|
74 |
+
instr = clean_task_instruction(instr, replacements)
|
75 |
+
step['observation']['natural_language_instruction'] = instr
|
76 |
+
|
77 |
+
return step
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
import tensorflow_datasets as tfds
|
82 |
+
from data.utils import dataset_to_path
|
83 |
+
|
84 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
85 |
+
DATASET_NAME = 'berkeley_autolab_ur5'
|
86 |
+
# Load the dataset
|
87 |
+
dataset = tfds.builder_from_directory(
|
88 |
+
builder_dir=dataset_to_path(
|
89 |
+
DATASET_NAME, DATASET_DIR))
|
90 |
+
dataset = dataset.as_dataset(split='all')
|
91 |
+
|
92 |
+
# Inspect the dataset
|
93 |
+
for episode in dataset:
|
94 |
+
for step in episode['steps']:
|
95 |
+
print(step)
|
data/preprocess_scripts/berkeley_cable_routing.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_rotation_matrix, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
7 |
+
"""
|
8 |
+
Convert terminate action to a boolean, where True means terminate.
|
9 |
+
"""
|
10 |
+
return tf.equal(terminate_act, tf.constant(1.0, dtype=tf.float32))
|
11 |
+
|
12 |
+
|
13 |
+
def process_step(step: dict) -> dict:
|
14 |
+
"""
|
15 |
+
Unify the action format and clean the task instruction.
|
16 |
+
|
17 |
+
DO NOT use python list, use tf.TensorArray instead.
|
18 |
+
"""
|
19 |
+
# Convert raw action to our action
|
20 |
+
action = step['action']
|
21 |
+
action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
|
22 |
+
|
23 |
+
eef_delta_pos = action['world_vector']
|
24 |
+
eef_ang=action['rotation_delta']
|
25 |
+
|
26 |
+
# No gripper_open found
|
27 |
+
# No base found
|
28 |
+
|
29 |
+
# Concatenate the action
|
30 |
+
arm_action=tf.concat([eef_delta_pos,eef_ang],axis=0)
|
31 |
+
action['arm_concat']=arm_action
|
32 |
+
#base_action = tf.concat([base_pos, base_ang], axis=0)
|
33 |
+
#action['base_concat'] = base_action
|
34 |
+
|
35 |
+
# Write the action format
|
36 |
+
action['format']=tf.constant("eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw")
|
37 |
+
|
38 |
+
# Convert raw state to our state
|
39 |
+
state = step['observation']
|
40 |
+
eef_pos = state['robot_state'][:3]
|
41 |
+
eef_ang = quaternion_to_rotation_matrix(state['robot_state'][3:])
|
42 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
43 |
+
|
44 |
+
# Concatenate the state
|
45 |
+
state['arm_concat']=tf.concat([eef_pos,eef_ang],axis=0)
|
46 |
+
|
47 |
+
# Write the state format
|
48 |
+
state['format'] = tf.constant(
|
49 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
50 |
+
|
51 |
+
# Define the task instruction
|
52 |
+
step['observation']['natural_language_instruction'] = tf.constant(
|
53 |
+
"Route cable through the tight-fitting clip mounted on the table.")
|
54 |
+
|
55 |
+
return step
|
56 |
+
|
57 |
+
|
58 |
+
if __name__ == "__main__":
|
59 |
+
import tensorflow_datasets as tfds
|
60 |
+
from data.utils import dataset_to_path
|
61 |
+
|
62 |
+
DATASET_DIR = 'data/datasets/openx_embod/'
|
63 |
+
DATASET_NAME = 'berkeley_cable_routing'
|
64 |
+
# Load the dataset
|
65 |
+
dataset = tfds.builder_from_directory(
|
66 |
+
builder_dir=dataset_to_path(
|
67 |
+
DATASET_NAME, DATASET_DIR))
|
68 |
+
dataset = dataset.as_dataset(split='all')
|
69 |
+
|
70 |
+
# Inspect the dataset
|
71 |
+
for episode in dataset:
|
72 |
+
for step in episode['steps']:
|
73 |
+
print(step)
|
data/preprocess_scripts/berkeley_gnm_sac_son.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
|
14 |
+
origin_action = step['action']
|
15 |
+
step['action']={}
|
16 |
+
action=step['action']
|
17 |
+
action['terminate'] = step['is_terminal']
|
18 |
+
|
19 |
+
eef_pos=tf.cast(origin_action, dtype=tf.float32)
|
20 |
+
eef_ang=tf.cast(step['action_angle'][2:3], dtype=tf.float32)
|
21 |
+
eef_ang = euler_to_quaternion(tf.stack([0,0,eef_ang[0]], axis=0))
|
22 |
+
# No base found
|
23 |
+
|
24 |
+
# Concatenate the action
|
25 |
+
action['arm_concat'] = tf.concat([eef_pos,eef_ang],axis=0)
|
26 |
+
|
27 |
+
# Write the action format
|
28 |
+
action['format'] = tf.constant(
|
29 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
|
30 |
+
|
31 |
+
# Convert raw state to our state
|
32 |
+
state = step['observation']
|
33 |
+
# Concatenate the state
|
34 |
+
eef_pos=tf.cast(state['position'],dtype=tf.float32)
|
35 |
+
eef_ang=tf.cast(state['yaw'],dtype=tf.float32)
|
36 |
+
eef_ang = euler_to_rotation_matrix(tf.stack([0,0,eef_ang[0]],axis=0))
|
37 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
38 |
+
state['arm_concat'] = tf.concat([eef_pos/100,eef_ang],axis=0)
|
39 |
+
# Write the state format
|
40 |
+
state['format'] = tf.constant(
|
41 |
+
"eef_pos_x,eef_pos_y,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
42 |
+
|
43 |
+
# Clean the task instruction
|
44 |
+
# Define the replacements (old, new) as a dictionary
|
45 |
+
replacements = {
|
46 |
+
'_': ' ',
|
47 |
+
'1f': ' ',
|
48 |
+
'4f': ' ',
|
49 |
+
'-': ' ',
|
50 |
+
'50': ' ',
|
51 |
+
'55': ' ',
|
52 |
+
'56': ' ',
|
53 |
+
|
54 |
+
}
|
55 |
+
instr = step['language_instruction']
|
56 |
+
instr = clean_task_instruction(instr, replacements)
|
57 |
+
step['observation']['natural_language_instruction'] = instr
|
58 |
+
|
59 |
+
return step
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
import tensorflow_datasets as tfds
|
64 |
+
from data.utils import dataset_to_path
|
65 |
+
|
66 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
67 |
+
DATASET_NAME = 'berkeley_gnm_sac_son'
|
68 |
+
# Load the dataset
|
69 |
+
dataset = tfds.builder_from_directory(
|
70 |
+
builder_dir=dataset_to_path(
|
71 |
+
DATASET_NAME, DATASET_DIR))
|
72 |
+
dataset = dataset.as_dataset(split='all')
|
73 |
+
|
74 |
+
# Inspect the dataset
|
75 |
+
for episode in dataset:
|
76 |
+
for step in episode['steps']:
|
77 |
+
print(step['action'][6:7])
|
78 |
+
|
data/preprocess_scripts/berkeley_rpt_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
|
5 |
+
|
6 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
7 |
+
"""
|
8 |
+
Convert terminate action to a boolean, where True means terminate.
|
9 |
+
"""
|
10 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
|
11 |
+
|
12 |
+
|
13 |
+
def process_step(step: dict) -> dict:
|
14 |
+
"""
|
15 |
+
Unify the action format and clean the task instruction.
|
16 |
+
|
17 |
+
DO NOT use python list, use tf.TensorArray instead.
|
18 |
+
"""
|
19 |
+
# Convert raw action to our action
|
20 |
+
action = step['action']
|
21 |
+
# Robot action, consists of [7 delta joint pos,1x gripper binary state].
|
22 |
+
delta_joint_pos = action[:7]
|
23 |
+
grip_open = tf.expand_dims(1 - action[7], axis=0)
|
24 |
+
|
25 |
+
# Concatenate the action
|
26 |
+
# action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
27 |
+
step['action'] = {}
|
28 |
+
action = step['action']
|
29 |
+
action['arm_concat'] = tf.concat([delta_joint_pos, grip_open], axis=0)
|
30 |
+
action['terminate'] = step['is_terminal']
|
31 |
+
|
32 |
+
# Write the action format
|
33 |
+
action['format'] = tf.constant(
|
34 |
+
"arm_joint_0_delta_pos,arm_joint_1_delta_pos,arm_joint_2_delta_pos,arm_joint_3_delta_pos,arm_joint_4_delta_pos,arm_joint_5_delta_pos,arm_joint_6_delta_pos,gripper_open")
|
35 |
+
|
36 |
+
# Convert raw state to our state
|
37 |
+
state = step['observation']
|
38 |
+
# xArm joint positions (7 DoF).
|
39 |
+
arm_joint_pos = state['joint_pos']
|
40 |
+
# Binary gripper state (1 - closed, 0 - open)
|
41 |
+
grip_open = tf.expand_dims(1 - tf.cast(state['gripper'],dtype=tf.float32), axis=0)
|
42 |
+
|
43 |
+
# Concatenate the state
|
44 |
+
state['arm_concat'] = tf.concat([arm_joint_pos, grip_open], axis=0)
|
45 |
+
|
46 |
+
# Write the state format
|
47 |
+
state['format'] = tf.constant(
|
48 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open")
|
49 |
+
|
50 |
+
# Clean the task instruction
|
51 |
+
# Define the replacements (old, new) as a dictionary
|
52 |
+
replacements = {
|
53 |
+
'_': ' ',
|
54 |
+
'1f': ' ',
|
55 |
+
'4f': ' ',
|
56 |
+
'-': ' ',
|
57 |
+
'50': ' ',
|
58 |
+
'55': ' ',
|
59 |
+
'56': ' ',
|
60 |
+
|
61 |
+
}
|
62 |
+
instr = step['language_instruction']
|
63 |
+
instr = clean_task_instruction(instr, replacements)
|
64 |
+
step['observation']['natural_language_instruction'] = instr
|
65 |
+
|
66 |
+
return step
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == "__main__":
|
70 |
+
import tensorflow_datasets as tfds
|
71 |
+
from data.utils import dataset_to_path
|
72 |
+
|
73 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
74 |
+
DATASET_NAME = 'fractal20220817_data'
|
75 |
+
# Load the dataset
|
76 |
+
dataset = tfds.builder_from_directory(
|
77 |
+
builder_dir=dataset_to_path(
|
78 |
+
DATASET_NAME, DATASET_DIR))
|
79 |
+
dataset = dataset.as_dataset(split='all')
|
80 |
+
|
81 |
+
# Inspect the dataset
|
82 |
+
for episode in dataset:
|
83 |
+
for step in episode['steps']:
|
84 |
+
print(step)
|
data/preprocess_scripts/calvin.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
3 |
+
import tensorflow as tf
|
4 |
+
import os
|
5 |
+
import fnmatch
|
6 |
+
import random
|
7 |
+
|
8 |
+
|
9 |
+
def _parse_function(proto):
|
10 |
+
keys_to_features = {
|
11 |
+
'action': tf.io.FixedLenFeature([], tf.string),
|
12 |
+
'robot_obs': tf.io.FixedLenFeature([], tf.string),
|
13 |
+
'rgb_static': tf.io.FixedLenFeature([], tf.string),
|
14 |
+
'rgb_gripper': tf.io.FixedLenFeature([], tf.string),
|
15 |
+
'terminate_episode': tf.io.FixedLenFeature([], tf.int64),
|
16 |
+
'instruction': tf.io.FixedLenFeature([], tf.string),
|
17 |
+
}
|
18 |
+
|
19 |
+
parsed_features = tf.io.parse_single_example(proto, keys_to_features)
|
20 |
+
|
21 |
+
action = tf.io.parse_tensor(parsed_features['action'], out_type=tf.float64)
|
22 |
+
robot_obs = tf.io.parse_tensor(parsed_features['robot_obs'], out_type=tf.float64)
|
23 |
+
rgb_static = tf.io.parse_tensor(parsed_features['rgb_static'], out_type=tf.uint8)
|
24 |
+
rgb_gripper = tf.io.parse_tensor(parsed_features['rgb_gripper'], out_type=tf.uint8)
|
25 |
+
instruction = parsed_features['instruction']
|
26 |
+
terminate_episode = tf.cast(parsed_features['terminate_episode'], tf.int64)
|
27 |
+
|
28 |
+
action = tf.reshape(action, [7])
|
29 |
+
action = tf.cast(action, tf.float32)
|
30 |
+
robot_obs = tf.reshape(robot_obs, [15])
|
31 |
+
robot_obs = tf.cast(robot_obs, tf.float32)
|
32 |
+
rgb_static = tf.reshape(rgb_static, [200, 200, 3])
|
33 |
+
rgb_gripper = tf.reshape(rgb_gripper, [84, 84, 3])
|
34 |
+
# RGB to BGR
|
35 |
+
# rgb_static = rgb_static[:, :, ::-1]
|
36 |
+
# rgb_gripper = rgb_gripper[:, :, ::-1]
|
37 |
+
|
38 |
+
return {
|
39 |
+
'action': action,
|
40 |
+
'observation':{
|
41 |
+
'robot_obs': robot_obs,
|
42 |
+
'rgb_static': rgb_static,
|
43 |
+
'rgb_gripper': rgb_gripper,
|
44 |
+
},
|
45 |
+
'instruction': instruction,
|
46 |
+
'terminate_episode': terminate_episode
|
47 |
+
}
|
48 |
+
|
49 |
+
|
50 |
+
def dataset_generator_from_tfrecords(seed):
|
51 |
+
tfrecord_path = './data/datasets/calvin/tfrecords/'
|
52 |
+
filepaths = []
|
53 |
+
for root, dirs, files in os.walk(tfrecord_path):
|
54 |
+
for filename in fnmatch.filter(files, '*.tfrecord'):
|
55 |
+
filepath = os.path.join(root, filename)
|
56 |
+
filepaths.append(filepath)
|
57 |
+
|
58 |
+
random.seed(seed)
|
59 |
+
random.shuffle(filepaths)
|
60 |
+
for filepath in filepaths:
|
61 |
+
raw_dataset = tf.data.TFRecordDataset(filepath)
|
62 |
+
dataset = raw_dataset.map(_parse_function)
|
63 |
+
yield {
|
64 |
+
'steps': dataset
|
65 |
+
}
|
66 |
+
|
67 |
+
|
68 |
+
def load_dataset(seed):
|
69 |
+
dataset = tf.data.Dataset.from_generator(
|
70 |
+
lambda: dataset_generator_from_tfrecords(seed),
|
71 |
+
output_signature={
|
72 |
+
'steps': tf.data.DatasetSpec(
|
73 |
+
element_spec={
|
74 |
+
'action': tf.TensorSpec(shape=(7,), dtype=tf.float32),
|
75 |
+
'observation':{
|
76 |
+
'robot_obs': tf.TensorSpec(shape=(15,), dtype=tf.float32),
|
77 |
+
'rgb_static': tf.TensorSpec(shape=(200,200,3), dtype=tf.uint8),
|
78 |
+
'rgb_gripper': tf.TensorSpec(shape=(84,84,3), dtype=tf.uint8),
|
79 |
+
},
|
80 |
+
'instruction': tf.TensorSpec(shape=(), dtype=tf.string),
|
81 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.int64),
|
82 |
+
}
|
83 |
+
)
|
84 |
+
}
|
85 |
+
)
|
86 |
+
|
87 |
+
return dataset
|
88 |
+
|
89 |
+
|
90 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
91 |
+
"""
|
92 |
+
Convert terminate action to a boolean, where True means terminate.
|
93 |
+
"""
|
94 |
+
return tf.where(
|
95 |
+
tf.equal(terminate_act, tf.constant(0, dtype=tf.int64)),
|
96 |
+
tf.constant(False),tf.constant(True))
|
97 |
+
|
98 |
+
|
99 |
+
def process_step(step: dict) -> dict:
|
100 |
+
"""
|
101 |
+
Unify the action format and clean the task instruction.
|
102 |
+
|
103 |
+
DO NOT use python list, use tf.TensorArray instead.
|
104 |
+
"""
|
105 |
+
# Convert raw action to our action
|
106 |
+
old_action = step['action']
|
107 |
+
step['action'] = {}
|
108 |
+
action = step['action']
|
109 |
+
step['action']['terminate'] = terminate_act_to_bool(step['terminate_episode'])
|
110 |
+
# ['actions']
|
111 |
+
# (dtype=np.float32, shape=(7,))
|
112 |
+
# tcp position (3): x,y,z in absolute world coordinates
|
113 |
+
# tcp orientation (3): euler angles x,y,z in absolute world coordinates
|
114 |
+
# gripper_action (1): binary (close = -1, open = 1)
|
115 |
+
eef_pos = old_action[:3]
|
116 |
+
eef_ang = euler_to_rotation_matrix(old_action[3:6])
|
117 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
118 |
+
gripper_open = (old_action[6] + 1) / 2
|
119 |
+
gripper_open = tf.expand_dims(gripper_open, axis=0)
|
120 |
+
|
121 |
+
# # No base found
|
122 |
+
arm_action = tf.concat([eef_pos, eef_ang, gripper_open], axis=0)
|
123 |
+
action['arm_concat'] = arm_action
|
124 |
+
# # Write the action format
|
125 |
+
action['format'] = tf.constant(
|
126 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
|
127 |
+
|
128 |
+
state = step['observation']
|
129 |
+
# ['robot_obs']
|
130 |
+
# (dtype=np.float32, shape=(15,))
|
131 |
+
# tcp position (3): x,y,z in world coordinates
|
132 |
+
# tcp orientation (3): euler angles x,y,z in world coordinates
|
133 |
+
# gripper opening width (1): in meter
|
134 |
+
# arm_joint_states (7): in rad
|
135 |
+
# gripper_action (1): binary (close = -1, open = 1)
|
136 |
+
eef_pos = state['robot_obs'][:3]
|
137 |
+
eef_ang = euler_to_rotation_matrix(state['robot_obs'][3:6])
|
138 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
139 |
+
gripper_open = (state['robot_obs'][14] + 1) / 2
|
140 |
+
gripper_open = tf.expand_dims(gripper_open, axis=0)
|
141 |
+
qpos = state['robot_obs'][7:14]
|
142 |
+
|
143 |
+
state['arm_concat'] = tf.concat([qpos,gripper_open,eef_pos,eef_ang], axis=0)
|
144 |
+
# # Write the state format
|
145 |
+
state['format'] = tf.constant(
|
146 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
147 |
+
|
148 |
+
# Clean the task instruction
|
149 |
+
# Define the replacements (old, new) as a dictionary
|
150 |
+
replacements = {
|
151 |
+
'_': ' ',
|
152 |
+
'1f': ' ',
|
153 |
+
'4f': ' ',
|
154 |
+
'-': ' ',
|
155 |
+
'50': ' ',
|
156 |
+
'55': ' ',
|
157 |
+
'56': ' ',
|
158 |
+
|
159 |
+
}
|
160 |
+
instr = step['instruction']
|
161 |
+
instr= clean_task_instruction(instr, replacements)
|
162 |
+
step['observation']['natural_language_instruction'] = instr
|
163 |
+
|
164 |
+
return step
|
165 |
+
|
166 |
+
|
167 |
+
if __name__ == "__main__":
|
168 |
+
import tensorflow_datasets as tfds
|
169 |
+
from data.utils import dataset_to_path
|
170 |
+
|
171 |
+
# Load the dataset
|
172 |
+
dataset = load_dataset(1717055919)
|
173 |
+
for data in dataset.take(1):
|
174 |
+
for step in data['steps']:
|
175 |
+
step = process_step(step)
|
176 |
+
print(step['observation']['natural_language_instruction'])
|
data/preprocess_scripts/cmu_franka_exploration_dataset_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler,euler_to_quaternion
|
4 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
5 |
+
"""
|
6 |
+
Convert terminate action to a boolean, where True means terminate.
|
7 |
+
"""
|
8 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
9 |
+
|
10 |
+
def process_step(step: dict) -> dict:
|
11 |
+
"""
|
12 |
+
Unify the action format and clean the task instruction.
|
13 |
+
|
14 |
+
DO NOT use python list, use tf.TensorArray instead.
|
15 |
+
"""
|
16 |
+
# Convert raw action to our action
|
17 |
+
|
18 |
+
origin_action = step['action']
|
19 |
+
step['action']={}
|
20 |
+
action=step['action']
|
21 |
+
action['terminate']=terminate_act_to_bool(origin_action[7])
|
22 |
+
|
23 |
+
# gripper_open: 1-open, 0-closed
|
24 |
+
|
25 |
+
eef_pos=origin_action[:3]
|
26 |
+
eef_ang=origin_action[3:6]
|
27 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
28 |
+
grip_open=origin_action[6:7]
|
29 |
+
# No base found
|
30 |
+
|
31 |
+
# Concatenate the action
|
32 |
+
action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
|
33 |
+
|
34 |
+
# Write the action format
|
35 |
+
action['format'] = tf.constant(
|
36 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
37 |
+
|
38 |
+
# No state found
|
39 |
+
|
40 |
+
# Clean the task instruction
|
41 |
+
# Define the replacements (old, new) as a dictionary
|
42 |
+
replacements = {
|
43 |
+
'_': ' ',
|
44 |
+
'1f': ' ',
|
45 |
+
'4f': ' ',
|
46 |
+
'-': ' ',
|
47 |
+
'50': ' ',
|
48 |
+
'55': ' ',
|
49 |
+
'56': ' ',
|
50 |
+
|
51 |
+
}
|
52 |
+
instr = step['language_instruction']
|
53 |
+
instr = clean_task_instruction(instr, replacements)
|
54 |
+
step['observation']['natural_language_instruction'] = instr
|
55 |
+
|
56 |
+
return step
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
import tensorflow_datasets as tfds
|
61 |
+
from data.utils import dataset_to_path
|
62 |
+
|
63 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
64 |
+
DATASET_NAME = 'cmu_franka_exploration_dataset_converted_externally_to_rlds'
|
65 |
+
# Load the dataset
|
66 |
+
dataset = tfds.builder_from_directory(
|
67 |
+
builder_dir=dataset_to_path(
|
68 |
+
DATASET_NAME, DATASET_DIR))
|
69 |
+
dataset = dataset.as_dataset(split='all')
|
70 |
+
|
71 |
+
# Inspect the dataset
|
72 |
+
for episode in dataset:
|
73 |
+
for step in episode['steps']:
|
74 |
+
print(step['action'][6:7])
|
75 |
+
|
data/preprocess_scripts/cmu_play_fusion.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
5 |
+
"""
|
6 |
+
Convert terminate action to a boolean, where True means terminate.
|
7 |
+
"""
|
8 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
9 |
+
|
10 |
+
def process_step(step: dict) -> dict:
|
11 |
+
"""
|
12 |
+
Unify the action format and clean the task instruction.
|
13 |
+
|
14 |
+
DO NOT use python list, use tf.TensorArray instead.
|
15 |
+
"""
|
16 |
+
# Convert raw action to our action
|
17 |
+
|
18 |
+
origin_action = step['action']
|
19 |
+
step['action']={}
|
20 |
+
action=step['action']
|
21 |
+
action['terminate']=terminate_act_to_bool(origin_action[8])
|
22 |
+
|
23 |
+
|
24 |
+
eef_pos=origin_action[:3]
|
25 |
+
# eef_ang=quaternion_to_euler(origin_action[3:7])
|
26 |
+
eef_ang = origin_action[3:7]
|
27 |
+
grip_open=origin_action[7:8]
|
28 |
+
# No base found
|
29 |
+
|
30 |
+
# Concatenate the action
|
31 |
+
action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
|
32 |
+
|
33 |
+
# Write the action format
|
34 |
+
action['format'] = tf.constant(
|
35 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
36 |
+
|
37 |
+
# Convert raw state to our state
|
38 |
+
state = step['observation']
|
39 |
+
# Concatenate the state
|
40 |
+
arm_joint_ang=state['state'][:7]
|
41 |
+
grip_open=state['state'][7:8] * 11.765 # rescale to [0, 1]
|
42 |
+
state['arm_concat'] = tf.concat([arm_joint_ang,grip_open],axis=0)
|
43 |
+
# Write the state format
|
44 |
+
state['format'] = tf.constant(
|
45 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos")
|
46 |
+
|
47 |
+
# Clean the task instruction
|
48 |
+
# Define the replacements (old, new) as a dictionary
|
49 |
+
replacements = {
|
50 |
+
'_': ' ',
|
51 |
+
'1f': ' ',
|
52 |
+
'4f': ' ',
|
53 |
+
'-': ' ',
|
54 |
+
'50': ' ',
|
55 |
+
'55': ' ',
|
56 |
+
'56': ' ',
|
57 |
+
|
58 |
+
}
|
59 |
+
instr = step['language_instruction']
|
60 |
+
instr = clean_task_instruction(instr, replacements)
|
61 |
+
step['observation']['natural_language_instruction'] = instr
|
62 |
+
|
63 |
+
return step
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
import tensorflow_datasets as tfds
|
68 |
+
from data.utils import dataset_to_path
|
69 |
+
|
70 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
71 |
+
DATASET_NAME = 'cmu_play_fusion'
|
72 |
+
# Load the dataset
|
73 |
+
dataset = tfds.builder_from_directory(
|
74 |
+
builder_dir=dataset_to_path(
|
75 |
+
DATASET_NAME, DATASET_DIR))
|
76 |
+
dataset = dataset.as_dataset(split='all')
|
77 |
+
|
78 |
+
# Inspect the dataset
|
79 |
+
for episode in dataset:
|
80 |
+
for step in episode['steps']:
|
81 |
+
print(step['action'][6:7])
|
82 |
+
|
data/preprocess_scripts/cmu_stretch.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler,euler_to_quaternion
|
4 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
5 |
+
"""
|
6 |
+
Convert terminate action to a boolean, where True means terminate.
|
7 |
+
"""
|
8 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
9 |
+
|
10 |
+
def process_step(step: dict) -> dict:
|
11 |
+
"""
|
12 |
+
Unify the action format and clean the task instruction.
|
13 |
+
|
14 |
+
DO NOT use python list, use tf.TensorArray instead.
|
15 |
+
"""
|
16 |
+
# Convert raw action to our action
|
17 |
+
|
18 |
+
origin_action = step['action']
|
19 |
+
step['action']={}
|
20 |
+
action=step['action']
|
21 |
+
action['terminate']=terminate_act_to_bool(origin_action[7])
|
22 |
+
|
23 |
+
|
24 |
+
eef_pos=origin_action[:3]
|
25 |
+
eef_ang=origin_action[3:6]
|
26 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
27 |
+
grip_open=origin_action[6:7]
|
28 |
+
# No base found
|
29 |
+
|
30 |
+
# Concatenate the action
|
31 |
+
action['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
|
32 |
+
|
33 |
+
# Write the action format
|
34 |
+
action['format'] = tf.constant(
|
35 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
36 |
+
|
37 |
+
# Convert raw state to our state
|
38 |
+
state = step['observation']
|
39 |
+
# Concatenate the state
|
40 |
+
eef_pos_x = state['state'][0:1]
|
41 |
+
eef_pos_z = state['state'][2:3]
|
42 |
+
grip_open = state['state'][3:4]
|
43 |
+
state['arm_concat'] = tf.concat(
|
44 |
+
[eef_pos_x, eef_pos_z, grip_open], axis=0)
|
45 |
+
# Write the state format
|
46 |
+
state['format'] = tf.constant(
|
47 |
+
"eef_pos_x,eef_pos_z,gripper_open")
|
48 |
+
|
49 |
+
# Clean the task instruction
|
50 |
+
# Define the replacements (old, new) as a dictionary
|
51 |
+
replacements = {
|
52 |
+
'_': ' ',
|
53 |
+
'1f': ' ',
|
54 |
+
'4f': ' ',
|
55 |
+
'-': ' ',
|
56 |
+
'50': ' ',
|
57 |
+
'55': ' ',
|
58 |
+
'56': ' ',
|
59 |
+
|
60 |
+
}
|
61 |
+
instr = step['language_instruction']
|
62 |
+
instr = clean_task_instruction(instr, replacements)
|
63 |
+
step['observation']['natural_language_instruction'] = instr
|
64 |
+
|
65 |
+
return step
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
import tensorflow_datasets as tfds
|
70 |
+
from data.utils import dataset_to_path
|
71 |
+
|
72 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
73 |
+
DATASET_NAME = 'cmu_stretch'
|
74 |
+
# Load the dataset
|
75 |
+
dataset = tfds.builder_from_directory(
|
76 |
+
builder_dir=dataset_to_path(
|
77 |
+
DATASET_NAME, DATASET_DIR))
|
78 |
+
dataset = dataset.as_dataset(split='all')
|
79 |
+
|
80 |
+
# Inspect the dataset
|
81 |
+
for episode in dataset:
|
82 |
+
for step in episode['steps']:
|
83 |
+
print(step['action'][6:7])
|
84 |
+
|
data/preprocess_scripts/droid.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
action_dict = step['action_dict']
|
14 |
+
|
15 |
+
# Robot action
|
16 |
+
eef_pos = action_dict['cartesian_position'][:3]
|
17 |
+
eef_ang = action_dict['cartesian_position'][3:6]
|
18 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
19 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
20 |
+
eef_pos_vel = action_dict['cartesian_velocity'][:3]
|
21 |
+
eef_ang_vel = action_dict['cartesian_velocity'][3:6]
|
22 |
+
joint_pos = action_dict['joint_position']
|
23 |
+
joint_vel = action_dict['joint_velocity']
|
24 |
+
grip_pos = action_dict['gripper_position']
|
25 |
+
grip_vel = action_dict['gripper_velocity']
|
26 |
+
|
27 |
+
# Concatenate the action
|
28 |
+
step['action'] = {}
|
29 |
+
action = step['action']
|
30 |
+
|
31 |
+
arm_action = tf.concat([eef_pos, eef_ang, eef_pos_vel, eef_ang_vel, joint_pos, joint_vel, grip_pos, grip_vel], axis=0)
|
32 |
+
action['arm_concat'] = arm_action
|
33 |
+
action['terminate'] = step['is_terminal']
|
34 |
+
|
35 |
+
# Write the action format
|
36 |
+
action['format'] = tf.constant(
|
37 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_0_vel")
|
38 |
+
|
39 |
+
# Convert raw state to our state
|
40 |
+
# Robot state
|
41 |
+
state = step['observation']
|
42 |
+
eef_pos = state['cartesian_position'][:3]
|
43 |
+
eef_ang = state['cartesian_position'][3:6]
|
44 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
45 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
46 |
+
joint_pos = state['joint_position']
|
47 |
+
grip_pos = 1 - state['gripper_position']
|
48 |
+
|
49 |
+
# Concatenate the state
|
50 |
+
state['arm_concat'] = tf.concat([
|
51 |
+
joint_pos,grip_pos,eef_pos,eef_ang], axis=0)
|
52 |
+
|
53 |
+
|
54 |
+
# Write the state format
|
55 |
+
state['format'] = tf.constant(
|
56 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
57 |
+
|
58 |
+
# Clean the task instruction
|
59 |
+
# Define the replacements (old, new) as a dictionary
|
60 |
+
replacements = {
|
61 |
+
'_': ' ',
|
62 |
+
'1f': ' ',
|
63 |
+
'4f': ' ',
|
64 |
+
'-': ' ',
|
65 |
+
'50': ' ',
|
66 |
+
'55': ' ',
|
67 |
+
'56': ' ',
|
68 |
+
|
69 |
+
}
|
70 |
+
instr = step['language_instruction']
|
71 |
+
instr = clean_task_instruction(instr, replacements)
|
72 |
+
step['observation']['natural_language_instruction'] = instr
|
73 |
+
|
74 |
+
return step
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
pass
|
data/preprocess_scripts/fractal20220817_data.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, quaternion_to_rotation_matrix,\
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
9 |
+
"""
|
10 |
+
Convert terminate action to a boolean, where True means terminate.
|
11 |
+
"""
|
12 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
|
13 |
+
|
14 |
+
|
15 |
+
def process_step(step: dict) -> dict:
|
16 |
+
"""
|
17 |
+
Unify the action format and clean the task instruction.
|
18 |
+
|
19 |
+
DO NOT use python list, use tf.TensorArray instead.
|
20 |
+
"""
|
21 |
+
# Convert raw action to our action
|
22 |
+
action = step['action']
|
23 |
+
action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
|
24 |
+
|
25 |
+
eef_delta_pos = action['world_vector']
|
26 |
+
eef_ang = action['rotation_delta']
|
27 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
28 |
+
grip_open = 1 - (action['gripper_closedness_action'] + 1) / 2
|
29 |
+
# Multiplied by 3 Hz to get units m/s and rad/s
|
30 |
+
base_delta_pos = action['base_displacement_vector'] * 3
|
31 |
+
base_delta_ang = action['base_displacement_vertical_rotation'] * 3
|
32 |
+
|
33 |
+
# Concatenate the action
|
34 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
35 |
+
action['arm_concat'] = arm_action
|
36 |
+
base_action = tf.concat([base_delta_pos, base_delta_ang], axis=0)
|
37 |
+
action['base_concat'] = base_action
|
38 |
+
|
39 |
+
# Write the action format
|
40 |
+
action['format'] = tf.constant(
|
41 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open,base_vel_x,base_vel_y,base_angular_vel")
|
42 |
+
|
43 |
+
# Convert raw state to our state
|
44 |
+
state = step['observation']
|
45 |
+
eef_pos = state['base_pose_tool_reached'][:3]
|
46 |
+
# eef_ang = quaternion_to_euler(state['base_pose_tool_reached'][3:])
|
47 |
+
eef_ang = quaternion_to_rotation_matrix(state['base_pose_tool_reached'][3:])
|
48 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
49 |
+
grip_open = 1 - state['gripper_closed']
|
50 |
+
|
51 |
+
# Concatenate the state
|
52 |
+
state['arm_concat'] = tf.concat([eef_pos, eef_ang, grip_open], axis=0)
|
53 |
+
|
54 |
+
# Write the state format
|
55 |
+
state['format'] = tf.constant(
|
56 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
|
57 |
+
|
58 |
+
# Clean the task instruction
|
59 |
+
# Define the replacements (old, new) as a dictionary
|
60 |
+
replacements = {
|
61 |
+
'_': ' ',
|
62 |
+
'1f': ' ',
|
63 |
+
'4f': ' ',
|
64 |
+
'-': ' ',
|
65 |
+
'50': ' ',
|
66 |
+
'55': ' ',
|
67 |
+
'56': ' ',
|
68 |
+
|
69 |
+
}
|
70 |
+
instr = step['observation']['natural_language_instruction']
|
71 |
+
instr = clean_task_instruction(instr, replacements)
|
72 |
+
step['observation']['natural_language_instruction'] = instr
|
73 |
+
|
74 |
+
return step
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
import tensorflow_datasets as tfds
|
79 |
+
from data.utils import dataset_to_path
|
80 |
+
|
81 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
82 |
+
DATASET_NAME = 'fractal20220817_data'
|
83 |
+
# Load the dataset
|
84 |
+
dataset = tfds.builder_from_directory(
|
85 |
+
builder_dir=dataset_to_path(
|
86 |
+
DATASET_NAME, DATASET_DIR))
|
87 |
+
dataset = dataset.as_dataset(split='all')
|
88 |
+
|
89 |
+
# Inspect the dataset
|
90 |
+
for episode in dataset:
|
91 |
+
for step in episode['steps']:
|
92 |
+
print(step)
|
data/preprocess_scripts/iamlab_cmu_pickup_insert_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction,quaternion_to_euler
|
4 |
+
|
5 |
+
def process_step(step: dict) -> dict:
|
6 |
+
"""
|
7 |
+
Unify the action format and clean the task instruction.
|
8 |
+
|
9 |
+
DO NOT use python list, use tf.TensorArray instead.
|
10 |
+
"""
|
11 |
+
# Convert raw action to our action
|
12 |
+
|
13 |
+
origin_action = step['action']
|
14 |
+
step['action']={}
|
15 |
+
action=step['action']
|
16 |
+
|
17 |
+
eef_delta_pos = origin_action[:3]
|
18 |
+
# delta ZYX euler angles
|
19 |
+
# eef_ang=quaternion_to_euler(origin_action[3:7])
|
20 |
+
eef_ang = origin_action[3:7]
|
21 |
+
grip_open=origin_action[7:8]
|
22 |
+
|
23 |
+
# No base found
|
24 |
+
|
25 |
+
# Concatenate the action
|
26 |
+
action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
|
27 |
+
|
28 |
+
# Write the action format
|
29 |
+
action['format'] = tf.constant(
|
30 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
31 |
+
|
32 |
+
# Convert raw state to our state
|
33 |
+
state = step['observation']
|
34 |
+
# Concatenate the state
|
35 |
+
# 7x robot joint angles, 1x gripper status, 6x joint torques, 6x end-effector force
|
36 |
+
arm_joint_ang=state['state'][:7]
|
37 |
+
|
38 |
+
grip_open=state['state'][7:8]
|
39 |
+
|
40 |
+
state['arm_concat'] = tf.concat([arm_joint_ang,grip_open],axis=0)
|
41 |
+
|
42 |
+
# Write the state format
|
43 |
+
state['format'] = tf.constant(
|
44 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open")
|
45 |
+
|
46 |
+
# Clean the task instruction
|
47 |
+
# Define the replacements (old, new) as a dictionary
|
48 |
+
replacements = {
|
49 |
+
'_': ' ',
|
50 |
+
'1f': ' ',
|
51 |
+
'4f': ' ',
|
52 |
+
'-': ' ',
|
53 |
+
'50': ' ',
|
54 |
+
'55': ' ',
|
55 |
+
'56': ' ',
|
56 |
+
|
57 |
+
}
|
58 |
+
instr = step['language_instruction']
|
59 |
+
instr = clean_task_instruction(instr, replacements)
|
60 |
+
step['observation']['natural_language_instruction'] = instr
|
61 |
+
|
62 |
+
return step
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
import tensorflow_datasets as tfds
|
67 |
+
from data.utils import dataset_to_path
|
68 |
+
|
69 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
70 |
+
DATASET_NAME = 'iamlab_cmu_pickup_insert_converted_externally_to_rlds'
|
71 |
+
# Load the dataset
|
72 |
+
dataset = tfds.builder_from_directory(
|
73 |
+
builder_dir=dataset_to_path(
|
74 |
+
DATASET_NAME, DATASET_DIR))
|
75 |
+
dataset = dataset.as_dataset(split='all')
|
76 |
+
|
77 |
+
# Inspect the dataset
|
78 |
+
for episode in dataset:
|
79 |
+
for step in episode['steps']:
|
80 |
+
print(step)
|
data/preprocess_scripts/libero_goal_no_noops.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
action_dict = step['action']
|
14 |
+
|
15 |
+
# Robot action
|
16 |
+
# eef_pos = action_dict['ee_pos'][:3]
|
17 |
+
# eef_ang = action_dict['ee_pos'][3:6]
|
18 |
+
# eef_ang = euler_to_rotation_matrix(eef_ang)
|
19 |
+
# eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
20 |
+
eef_pos_vel = action_dict[:3]
|
21 |
+
eef_ang_vel = action_dict[3:6]
|
22 |
+
# joint_pos = action_dict['joint_pos'][:-1]
|
23 |
+
# joint_vel = action_dict['delta_joint'][:-1]
|
24 |
+
grip_pos = 1 - tf.clip_by_value(action_dict[-1:], 0, 1)
|
25 |
+
|
26 |
+
# grip_vel = action_dict['gripper_velocity']
|
27 |
+
|
28 |
+
# Concatenate the action
|
29 |
+
step['action'] = {}
|
30 |
+
action = step['action']
|
31 |
+
|
32 |
+
arm_action = tf.concat([eef_pos_vel, eef_ang_vel, grip_pos], axis=0)
|
33 |
+
action['arm_concat'] = arm_action
|
34 |
+
# action['terminate'] = step['is_terminal']
|
35 |
+
|
36 |
+
# Write the action format
|
37 |
+
action['format'] = tf.constant(
|
38 |
+
"eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,gripper_joint_0_pos")
|
39 |
+
|
40 |
+
# Convert raw state to our state
|
41 |
+
# Robot state
|
42 |
+
state = step['observation']
|
43 |
+
# print(state.keys())
|
44 |
+
# image = step['observation']['image']
|
45 |
+
eef_pos = state['state'][:3]
|
46 |
+
eef_ang = state['state'][3:6]
|
47 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
48 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
49 |
+
# joint_pos = state['joint_pos'][:-1]
|
50 |
+
grip_pos = state['state'][-2:]
|
51 |
+
|
52 |
+
# Concatenate the state
|
53 |
+
state['arm_concat'] = tf.concat([
|
54 |
+
grip_pos,eef_pos,eef_ang], axis=0)
|
55 |
+
|
56 |
+
|
57 |
+
# Write the state format
|
58 |
+
state['format'] = tf.constant(
|
59 |
+
"gripper_joint_0_pos,gripper_joint_1_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
60 |
+
|
61 |
+
# Clean the task instruction
|
62 |
+
# Define the replacements (old, new) as a dictionary
|
63 |
+
replacements = {
|
64 |
+
'_': ' ',
|
65 |
+
'1f': ' ',
|
66 |
+
'4f': ' ',
|
67 |
+
'-': ' ',
|
68 |
+
'50': ' ',
|
69 |
+
'55': ' ',
|
70 |
+
'56': ' ',
|
71 |
+
|
72 |
+
}
|
73 |
+
instr = step['language_instruction']
|
74 |
+
# instr = clean_task_instruction(instr, replacements)
|
75 |
+
step['observation'] = state
|
76 |
+
step['observation']['natural_language_instruction'] = instr
|
77 |
+
|
78 |
+
return step
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
pass
|
data/preprocess_scripts/libero_spatial_no_noops.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
action_dict = step['action']
|
14 |
+
|
15 |
+
# Robot action
|
16 |
+
# eef_pos = action_dict['ee_pos'][:3]
|
17 |
+
# eef_ang = action_dict['ee_pos'][3:6]
|
18 |
+
# eef_ang = euler_to_rotation_matrix(eef_ang)
|
19 |
+
# eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
20 |
+
eef_pos_vel = action_dict[:3]
|
21 |
+
eef_ang_vel = action_dict[3:6]
|
22 |
+
# joint_pos = action_dict['joint_pos'][:-1]
|
23 |
+
# joint_vel = action_dict['delta_joint'][:-1]
|
24 |
+
grip_pos = 1 - tf.clip_by_value(action_dict[-1:], 0, 1)
|
25 |
+
|
26 |
+
# grip_vel = action_dict['gripper_velocity']
|
27 |
+
|
28 |
+
# Concatenate the action
|
29 |
+
step['action'] = {}
|
30 |
+
action = step['action']
|
31 |
+
|
32 |
+
arm_action = tf.concat([eef_pos_vel, eef_ang_vel, grip_pos], axis=0)
|
33 |
+
action['arm_concat'] = arm_action
|
34 |
+
# action['terminate'] = step['is_terminal']
|
35 |
+
|
36 |
+
# Write the action format
|
37 |
+
action['format'] = tf.constant(
|
38 |
+
"eef_vel_x,eef_vel_y,eef_vel_z,eef_angular_vel_roll,eef_angular_vel_pitch,eef_angular_vel_yaw,gripper_joint_0_pos")
|
39 |
+
|
40 |
+
# Convert raw state to our state
|
41 |
+
# Robot state
|
42 |
+
state = step['observation']
|
43 |
+
# print(state.keys())
|
44 |
+
# image = step['observation']['image']
|
45 |
+
eef_pos = state['state'][:3]
|
46 |
+
eef_ang = state['state'][3:6]
|
47 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
48 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
49 |
+
# joint_pos = state['joint_pos'][:-1]
|
50 |
+
grip_pos = state['state'][-2:]
|
51 |
+
|
52 |
+
# Concatenate the state
|
53 |
+
state['arm_concat'] = tf.concat([
|
54 |
+
grip_pos,eef_pos,eef_ang], axis=0)
|
55 |
+
|
56 |
+
|
57 |
+
# Write the state format
|
58 |
+
state['format'] = tf.constant(
|
59 |
+
"gripper_joint_0_pos,gripper_joint_1_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
60 |
+
|
61 |
+
# Clean the task instruction
|
62 |
+
# Define the replacements (old, new) as a dictionary
|
63 |
+
replacements = {
|
64 |
+
'_': ' ',
|
65 |
+
'1f': ' ',
|
66 |
+
'4f': ' ',
|
67 |
+
'-': ' ',
|
68 |
+
'50': ' ',
|
69 |
+
'55': ' ',
|
70 |
+
'56': ' ',
|
71 |
+
|
72 |
+
}
|
73 |
+
instr = step['language_instruction']
|
74 |
+
# instr = clean_task_instruction(instr, replacements)
|
75 |
+
step['observation'] = state
|
76 |
+
step['observation']['natural_language_instruction'] = instr
|
77 |
+
|
78 |
+
return step
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
pass
|
data/preprocess_scripts/nyu_rot_dataset_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
def process_step(step: dict) -> dict:
|
7 |
+
"""
|
8 |
+
Unify the action format and clean the task instruction.
|
9 |
+
|
10 |
+
DO NOT use python list, use tf.TensorArray instead.
|
11 |
+
"""
|
12 |
+
# Convert raw action to our action
|
13 |
+
|
14 |
+
origin_action = step['action']
|
15 |
+
step['action']={}
|
16 |
+
action=step['action']
|
17 |
+
action['terminate'] = step['is_terminal']
|
18 |
+
|
19 |
+
eef_delta_pos = origin_action[:3]
|
20 |
+
eef_ang=origin_action[3:6]
|
21 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
22 |
+
# gripper_open: 0-open, 1-closed
|
23 |
+
grip_open=tf.where(tf.equal(origin_action[6:],tf.constant(0.0)),tf.constant(1.0),tf.constant(0.0))
|
24 |
+
|
25 |
+
# No base found
|
26 |
+
|
27 |
+
# Concatenate the action
|
28 |
+
action['arm_concat'] = tf.concat([eef_delta_pos,eef_ang,grip_open],axis=0)
|
29 |
+
|
30 |
+
# Write the action format
|
31 |
+
action['format'] = tf.constant(
|
32 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
33 |
+
|
34 |
+
# Convert raw state to our state
|
35 |
+
state = step['observation']
|
36 |
+
eef_pos=state['state'][:3]
|
37 |
+
eef_ang=state['state'][3:6]
|
38 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
39 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
40 |
+
grip_open=1-state['state'][6:7]
|
41 |
+
# Concatenate the state
|
42 |
+
state['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_open],axis=0)
|
43 |
+
|
44 |
+
# Write the state format
|
45 |
+
state['format'] = tf.constant(
|
46 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
|
47 |
+
|
48 |
+
# Clean the task instruction
|
49 |
+
# Define the replacements (old, new) as a dictionary
|
50 |
+
replacements = {
|
51 |
+
'_': ' ',
|
52 |
+
'1f': ' ',
|
53 |
+
'4f': ' ',
|
54 |
+
'-': ' ',
|
55 |
+
'50': ' ',
|
56 |
+
'55': ' ',
|
57 |
+
'56': ' ',
|
58 |
+
|
59 |
+
}
|
60 |
+
instr = step['language_instruction']
|
61 |
+
instr = clean_task_instruction(instr, replacements)
|
62 |
+
step['observation']['natural_language_instruction'] = instr
|
63 |
+
|
64 |
+
return step
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
import tensorflow_datasets as tfds
|
69 |
+
from data.utils import dataset_to_path
|
70 |
+
|
71 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
72 |
+
DATASET_NAME = 'nyu_rot_dataset_converted_externally_to_rlds'
|
73 |
+
# Load the dataset
|
74 |
+
dataset = tfds.builder_from_directory(
|
75 |
+
builder_dir=dataset_to_path(
|
76 |
+
DATASET_NAME, DATASET_DIR))
|
77 |
+
dataset = dataset.as_dataset(split='all')
|
78 |
+
|
79 |
+
# Inspect the dataset
|
80 |
+
for episode in dataset:
|
81 |
+
for step in episode['steps']:
|
82 |
+
print(step)
|
data/preprocess_scripts/robo_net.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
|
5 |
+
rotation_matrix_to_ortho6d
|
6 |
+
|
7 |
+
|
8 |
+
def process_step(step: dict) -> dict:
|
9 |
+
"""
|
10 |
+
Unify the action format and clean the task instruction.
|
11 |
+
|
12 |
+
DO NOT use python list, use tf.TensorArray instead.
|
13 |
+
"""
|
14 |
+
# Convert raw action to our action
|
15 |
+
action = step['action']
|
16 |
+
eef_delta_pos = action[:3]
|
17 |
+
eef_delta_angle_yaw = action[3:4]
|
18 |
+
eef_ang = tf.stack([0.0, 0.0, eef_delta_angle_yaw[0]], axis=0)
|
19 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
20 |
+
eef_gripper_open = (1 - action[4:5]) / 2
|
21 |
+
|
22 |
+
step['action'] = {}
|
23 |
+
action = step['action']
|
24 |
+
action['terminate'] = step['is_terminal']
|
25 |
+
|
26 |
+
# No base found
|
27 |
+
|
28 |
+
# Concatenate the action
|
29 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang, eef_gripper_open], axis=0)
|
30 |
+
action['arm_concat'] = arm_action
|
31 |
+
|
32 |
+
# Write the action format
|
33 |
+
action['format'] = tf.constant(
|
34 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
35 |
+
|
36 |
+
# Convert raw state to our state
|
37 |
+
state = step['observation']
|
38 |
+
eef_pos = state['state'][:3]
|
39 |
+
eef_ang_yaw = state['state'][3:4]
|
40 |
+
eef_ang = tf.stack([0.0, 0.0, eef_ang_yaw[0]], axis=0)
|
41 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
42 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
43 |
+
grip_joint_pos = state['state'][4:5]
|
44 |
+
# If abs(grip_joint_pos) > 3.15, then convert it to the radian
|
45 |
+
grip_joint_pos = tf.cond(tf.greater(tf.abs(grip_joint_pos), 3.15),
|
46 |
+
lambda: grip_joint_pos / 180 * np.pi,
|
47 |
+
lambda: grip_joint_pos)
|
48 |
+
# Concatenate the state
|
49 |
+
state['arm_concat'] = tf.concat([eef_pos,eef_ang,grip_joint_pos],axis=0)
|
50 |
+
|
51 |
+
# Write the state format
|
52 |
+
state['format'] = tf.constant(
|
53 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_open")
|
54 |
+
|
55 |
+
# Clean the task instruction
|
56 |
+
# Define the replacements (old, new) as a dictionary
|
57 |
+
replacements = {
|
58 |
+
'_': ' ',
|
59 |
+
'1f': ' ',
|
60 |
+
'4f': ' ',
|
61 |
+
'-': ' ',
|
62 |
+
'50': ' ',
|
63 |
+
'55': ' ',
|
64 |
+
'56': ' ',
|
65 |
+
|
66 |
+
}
|
67 |
+
instr = step['language_instruction']
|
68 |
+
instr = clean_task_instruction(instr, replacements)
|
69 |
+
step['observation']['natural_language_instruction'] = instr
|
70 |
+
|
71 |
+
return step
|
data/preprocess_scripts/robomimic_lift_ph.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_datasets as tfds
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
|
5 |
+
|
6 |
+
def load_dataset():
|
7 |
+
builder = tfds.builder('robomimic_ph/lift_ph_image')
|
8 |
+
builder.download_and_prepare()
|
9 |
+
ds = builder.as_dataset(split='train', shuffle_files=True)
|
10 |
+
return ds
|
11 |
+
|
12 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
13 |
+
"""
|
14 |
+
Convert terminate action to a boolean, where True means terminate.
|
15 |
+
"""
|
16 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
17 |
+
|
18 |
+
def process_step(step: dict) -> dict:
|
19 |
+
"""
|
20 |
+
Unify the action format and clean the task instruction.
|
21 |
+
|
22 |
+
DO NOT use python list, use tf.TensorArray instead.
|
23 |
+
"""
|
24 |
+
# format refers to https://www.tensorflow.org/datasets/catalog/robomimic_mg
|
25 |
+
# Convert raw action to our action
|
26 |
+
eef = step['action']
|
27 |
+
step['action'] = {}
|
28 |
+
action = step['action']
|
29 |
+
action['terminate'] = step['is_terminal']
|
30 |
+
|
31 |
+
eef_delta_pos = eef[:3]
|
32 |
+
eef_ang = quaternion_to_euler(eef[3:])
|
33 |
+
|
34 |
+
# No base found
|
35 |
+
|
36 |
+
# Concatenate the action
|
37 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
|
38 |
+
action['arm_concat'] = arm_action
|
39 |
+
|
40 |
+
# Write the action format
|
41 |
+
action['format'] = tf.constant(
|
42 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
|
43 |
+
|
44 |
+
# Convert raw state to our state
|
45 |
+
state = step['observation']
|
46 |
+
arm_joint_pos = state['robot0_joint_pos']
|
47 |
+
arm_joint_vel = state['robot0_joint_vel']
|
48 |
+
gripper_pos = state['robot0_gripper_qpos']
|
49 |
+
gripper_vel = state['robot0_gripper_qvel']
|
50 |
+
eef_pos = state['robot0_eef_pos']
|
51 |
+
eef_ang = quaternion_to_euler(state['robot0_eef_quat'])
|
52 |
+
|
53 |
+
state['arm_concat'] = tf.concat([arm_joint_pos, arm_joint_vel, gripper_pos,gripper_vel,eef_pos,eef_ang], axis=0)
|
54 |
+
# convert to tf32
|
55 |
+
state['arm_concat'] = tf.cast(state['arm_concat'], tf.float32)
|
56 |
+
# Write the state format
|
57 |
+
state['format'] = tf.constant(
|
58 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_0_vel,gripper_joint_1_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_roll,eef_angle_pitch,eef_angle_yaw")
|
59 |
+
|
60 |
+
# Clean the task instruction
|
61 |
+
# Define the replacements (old, new) as a dictionary
|
62 |
+
replacements = {
|
63 |
+
'_': ' ',
|
64 |
+
'1f': ' ',
|
65 |
+
'4f': ' ',
|
66 |
+
'-': ' ',
|
67 |
+
'50': ' ',
|
68 |
+
'55': ' ',
|
69 |
+
'56': ' ',
|
70 |
+
|
71 |
+
}
|
72 |
+
# manual added by lbg
|
73 |
+
instr = "lift the object on the table"
|
74 |
+
instr = clean_task_instruction(instr, replacements)
|
75 |
+
step['observation']['natural_language_instruction'] = instr
|
76 |
+
|
77 |
+
return step
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
import tensorflow_datasets as tfds
|
82 |
+
from data.utils import dataset_to_path
|
83 |
+
|
84 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
85 |
+
DATASET_NAME = 'roboturk'
|
86 |
+
# Load the dataset
|
87 |
+
dataset = tfds.builder_from_directory(
|
88 |
+
builder_dir=dataset_to_path(
|
89 |
+
DATASET_NAME, DATASET_DIR))
|
90 |
+
dataset = dataset.as_dataset(split='all').take(1)
|
91 |
+
|
92 |
+
# Inspect the dataset
|
93 |
+
ze=tf.constant(0.0)
|
94 |
+
for episode in dataset:
|
95 |
+
for step in episode['steps']:
|
96 |
+
print(step)
|
97 |
+
break
|
data/preprocess_scripts/robomimic_square_ph.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_datasets as tfds
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
|
5 |
+
|
6 |
+
def load_dataset():
|
7 |
+
builder = tfds.builder('robomimic_ph/square_ph_image')
|
8 |
+
builder.download_and_prepare()
|
9 |
+
ds = builder.as_dataset(split='train', shuffle_files=True)
|
10 |
+
return ds
|
11 |
+
|
12 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
13 |
+
"""
|
14 |
+
Convert terminate action to a boolean, where True means terminate.
|
15 |
+
"""
|
16 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
17 |
+
|
18 |
+
def process_step(step: dict) -> dict:
|
19 |
+
"""
|
20 |
+
Unify the action format and clean the task instruction.
|
21 |
+
|
22 |
+
DO NOT use python list, use tf.TensorArray instead.
|
23 |
+
"""
|
24 |
+
# format refers to https://www.tensorflow.org/datasets/catalog/robomimic_mg
|
25 |
+
# Convert raw action to our action
|
26 |
+
eef = step['action']
|
27 |
+
step['action'] = {}
|
28 |
+
action = step['action']
|
29 |
+
action['terminate'] = step['is_terminal']
|
30 |
+
|
31 |
+
eef_delta_pos = eef[:3]
|
32 |
+
eef_ang = quaternion_to_euler(eef[3:])
|
33 |
+
|
34 |
+
# No base found
|
35 |
+
|
36 |
+
# Concatenate the action
|
37 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
|
38 |
+
action['arm_concat'] = arm_action
|
39 |
+
|
40 |
+
# Write the action format
|
41 |
+
action['format'] = tf.constant(
|
42 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
|
43 |
+
|
44 |
+
# Convert raw state to our state
|
45 |
+
state = step['observation']
|
46 |
+
arm_joint_pos = state['robot0_joint_pos']
|
47 |
+
arm_joint_vel = state['robot0_joint_vel']
|
48 |
+
gripper_pos = state['robot0_gripper_qpos']
|
49 |
+
gripper_vel = state['robot0_gripper_qvel']
|
50 |
+
eef_pos = state['robot0_eef_pos']
|
51 |
+
eef_ang = quaternion_to_euler(state['robot0_eef_quat'])
|
52 |
+
|
53 |
+
state['arm_concat'] = tf.concat([arm_joint_pos, arm_joint_vel, gripper_pos,gripper_vel,eef_pos,eef_ang], axis=0)
|
54 |
+
# convert to tf32
|
55 |
+
state['arm_concat'] = tf.cast(state['arm_concat'], tf.float32)
|
56 |
+
# Write the state format
|
57 |
+
state['format'] = tf.constant(
|
58 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_0_vel,gripper_joint_1_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_roll,eef_angle_pitch,eef_angle_yaw")
|
59 |
+
|
60 |
+
# Clean the task instruction
|
61 |
+
# Define the replacements (old, new) as a dictionary
|
62 |
+
replacements = {
|
63 |
+
'_': ' ',
|
64 |
+
'1f': ' ',
|
65 |
+
'4f': ' ',
|
66 |
+
'-': ' ',
|
67 |
+
'50': ' ',
|
68 |
+
'55': ' ',
|
69 |
+
'56': ' ',
|
70 |
+
|
71 |
+
}
|
72 |
+
# manual added by lbg
|
73 |
+
instr = "move the square across the cube"
|
74 |
+
instr = clean_task_instruction(instr, replacements)
|
75 |
+
step['observation']['natural_language_instruction'] = instr
|
76 |
+
|
77 |
+
return step
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
import tensorflow_datasets as tfds
|
82 |
+
from data.utils import dataset_to_path
|
83 |
+
|
84 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
85 |
+
DATASET_NAME = 'roboturk'
|
86 |
+
# Load the dataset
|
87 |
+
dataset = tfds.builder_from_directory(
|
88 |
+
builder_dir=dataset_to_path(
|
89 |
+
DATASET_NAME, DATASET_DIR))
|
90 |
+
dataset = dataset.as_dataset(split='all').take(1)
|
91 |
+
|
92 |
+
# Inspect the dataset
|
93 |
+
ze=tf.constant(0.0)
|
94 |
+
for episode in dataset:
|
95 |
+
for step in episode['steps']:
|
96 |
+
print(step)
|
97 |
+
break
|
data/preprocess_scripts/roboset.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_datasets as tfds
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
import tensorflow as tf
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
import imageio
|
10 |
+
import concurrent.futures
|
11 |
+
import fnmatch
|
12 |
+
import cv2
|
13 |
+
import random
|
14 |
+
path2json = {
|
15 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaPlanarPushReal_v2d/set_17_plannar_push_eval": [
|
16 |
+
"Push the green object to the red line."
|
17 |
+
],
|
18 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientRealRP03_v2d_set16/set_16_bin_reorient_3": [
|
19 |
+
"Pick up the bottle and and stand it upright."
|
20 |
+
],
|
21 |
+
"mnt/raid5/data/jaydv/robohive_base/demonstrations/orange_block/set_7_orange_block": [
|
22 |
+
"Pick up the orange block."
|
23 |
+
],
|
24 |
+
"mnt/raid5/data/jaydv/robohive_base/demonstrations/wooden_block": [
|
25 |
+
"Pick up the wooden block."
|
26 |
+
],
|
27 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientReal_v2d-orig/set_15_bin_reorient_eval_2": [
|
28 |
+
"Pick up the bottle and and stand it upright."
|
29 |
+
],
|
30 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPushReal_v2d_set14/set_14_bin_push": [
|
31 |
+
"Push the object to the red line."
|
32 |
+
],
|
33 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickRealRP05_v2d/set_11_bottle_pick": [
|
34 |
+
"Pick up the bottle."
|
35 |
+
],
|
36 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickRealRP03_v2d/set_10_bin_pick_2004": [
|
37 |
+
"Pick up the wooden block."
|
38 |
+
],
|
39 |
+
"mnt/raid5/data/jaydv/bin_pick_data_30029/set_2_softtoys": [
|
40 |
+
"Pick up one toy in the basket."
|
41 |
+
],
|
42 |
+
"home/jaydv/Documents/RoboSet/pick_banana_from_toaster_place_on_table_data": [
|
43 |
+
"Pick banana from toaster and place on table."
|
44 |
+
],
|
45 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinPickReal_v2d/set_12_bin_pick_eval": [
|
46 |
+
"Pick up the block."
|
47 |
+
],
|
48 |
+
"mnt/raid5/data/roboset/v0.3/flap_open_toaster_oven_data": [
|
49 |
+
"Flap open toaster."
|
50 |
+
],
|
51 |
+
"mnt/raid5/data/jaydv/robohive_base/episodes/franka-FrankaBinReorientReal_v2d_set13/set_13_bin_reorient": [
|
52 |
+
"Pick up the bottle and stand it upright."
|
53 |
+
],
|
54 |
+
"home/jaydv/Documents/RoboSet/drag_mug_from_right_to_left_data": [
|
55 |
+
"Drag mug right to left."
|
56 |
+
],
|
57 |
+
"home/jaydv/Documents/RoboSet/drag_strainer_backward_data": [
|
58 |
+
"Drag strainer backwards."
|
59 |
+
],
|
60 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_slide_close_drawer_scene_4": [
|
61 |
+
"Slide and close the drawer."
|
62 |
+
],
|
63 |
+
"mnt/raid5/data/roboset/v0.4/heat_soup/scene_4/baking_slide_in_bowl_scene_4": [
|
64 |
+
"Place the bowl into the container."
|
65 |
+
],
|
66 |
+
"set_5_bottle_cube_14": [
|
67 |
+
"Pick up the bottle."
|
68 |
+
],
|
69 |
+
"home/jaydv/Documents/RoboSet/drag_mug_forward_data": [
|
70 |
+
"Drag mug forwards."
|
71 |
+
],
|
72 |
+
"mnt/raid5/data/roboset/v0.4/clean_kitchen/scene_3/clean_kitchen_pick_towel_scene_3": [
|
73 |
+
"Pick up the towel from the oven."
|
74 |
+
],
|
75 |
+
"mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_pick_bowl_scene_2": [
|
76 |
+
"Pick up the bowl."
|
77 |
+
],
|
78 |
+
"mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_slide_in_bowl_scene_2": [
|
79 |
+
"Place the bowl into the oven."
|
80 |
+
],
|
81 |
+
"mnt/raid5/data/roboset/v0.4/heat_soup/scene_2/baking_close_oven_scene_2": [
|
82 |
+
"Flap and close the oven."
|
83 |
+
],
|
84 |
+
"home/jaydv/Documents/RoboSet/pick_banana_place_in_strainer_data": [
|
85 |
+
"Pick banana and place it in strainer."
|
86 |
+
],
|
87 |
+
"mnt/raid5/data/roboset/v0.3/pick_banana_place_in_mug_data": [
|
88 |
+
"Pick banana and place it in mug."
|
89 |
+
],
|
90 |
+
"mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_pick_tea_scene_2": [
|
91 |
+
"Pick up the tea from the container."
|
92 |
+
],
|
93 |
+
"home/jaydv/Documents/RoboSet/drag_strainer_forward_data": [
|
94 |
+
"Drag strainer forwards."
|
95 |
+
],
|
96 |
+
"mnt/raid5/data/roboset/v0.3/pick_ketchup_from_strainer_place_on_table_data": [
|
97 |
+
"Pick ketchup from strainer and place it on the table."
|
98 |
+
],
|
99 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_place_butter_scene_4": [
|
100 |
+
"Place the butter on the cutting board."
|
101 |
+
],
|
102 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_slide_open_drawer_scene_1": [
|
103 |
+
"Slide and open the drawer."
|
104 |
+
],
|
105 |
+
"mnt/raid5/data/roboset/v0.3/drag_strainer_right_to_left_data": [
|
106 |
+
"Drag strainer right to left."
|
107 |
+
],
|
108 |
+
"home/jaydv/Documents/RoboSet/pick_banana_from_plate_place_on_table_data": [
|
109 |
+
"Pick banana from plate and place on table."
|
110 |
+
],
|
111 |
+
"mnt/raid5/data/roboset/v0.3/pick_ketchup_from_plate_place_on_table_data": [
|
112 |
+
"Pick ketchup from plate and place it on table."
|
113 |
+
],
|
114 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_slide_close_drawer_scene_1": [
|
115 |
+
"Slide and close the drawer."
|
116 |
+
],
|
117 |
+
"mnt/raid5/data/roboset/v0.3/drag_strainer_left_to_right_data": [
|
118 |
+
"Drag strainer left to right."
|
119 |
+
],
|
120 |
+
"pick_ketchup_place_on_toaster_data": [
|
121 |
+
"Pick ketchup from table and place on toaster."
|
122 |
+
],
|
123 |
+
"mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_place_tea_scene_2": [
|
124 |
+
"Place the tea into the cup."
|
125 |
+
],
|
126 |
+
"mnt/raid5/data/roboset/v0.3/pick_ketchup_place_in_strainer_data": [
|
127 |
+
"Pick ketchup from the table and place it in strainer."
|
128 |
+
],
|
129 |
+
"home/jaydv/Documents/RoboSet/pick_ketchup_place_on_plate_data": [
|
130 |
+
"Pick ketchup from table and place on plate."
|
131 |
+
],
|
132 |
+
"home/jaydv/Documents/RoboSet/drag_mug_backward_data": [
|
133 |
+
"Drag mug backwards."
|
134 |
+
],
|
135 |
+
"set_1_blocks_897": [
|
136 |
+
"Pick up one block in the basket."
|
137 |
+
],
|
138 |
+
"mnt/raid5/data/roboset/v0.4/make_tea/scene_2/make_tea_place_lid_scene_2": [
|
139 |
+
"Place lid on the cutting board."
|
140 |
+
],
|
141 |
+
"mnt/raid5/data/roboset/v0.4/heat_soup/scene_4/baking_pick_bowl_scene_4": [
|
142 |
+
"Pick up the bowl."
|
143 |
+
],
|
144 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_pick_butter_scene_4": [
|
145 |
+
"Pick up the butter from the drawer."
|
146 |
+
],
|
147 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_pick_butter_scene_1": [
|
148 |
+
"Pick up the butter from the drawer."
|
149 |
+
],
|
150 |
+
"home/jaydv/Documents/RoboSet/flap_close_toaster_oven_data": [
|
151 |
+
"Flap close toaster."
|
152 |
+
],
|
153 |
+
"home/jaydv/Documents/RoboSet/drag_mug_from_left_to_right_data": [
|
154 |
+
"Drag mug left to right."
|
155 |
+
],
|
156 |
+
"set_6_planar_push_120": [
|
157 |
+
"Push the object from left to right."
|
158 |
+
],
|
159 |
+
"mnt/raid5/data/roboset/v0.4/clean_kitchen/scene_3/clean_kitchen_slide_close_drawer_scene_3": [
|
160 |
+
"Slide and close the drawer."
|
161 |
+
],
|
162 |
+
"set_4_med_block_7": [
|
163 |
+
"Pick up the wooden block."
|
164 |
+
],
|
165 |
+
"mnt/raid5/data/roboset/v0.3/pick_banana_place_on_toaster_data": [
|
166 |
+
"Pick banana from table and place on toaster."
|
167 |
+
],
|
168 |
+
"mnt/raid5/data/roboset/v0.3/pick_ketchup_from_toaster_place_on_table_data": [
|
169 |
+
"Pick ketchup from toaster and place it on table."
|
170 |
+
],
|
171 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_1/baking_prep_place_butter_scene_1": [
|
172 |
+
"Place the butter on the cutting board."
|
173 |
+
],
|
174 |
+
"mnt/raid5/data/roboset/v0.4/baking_prep/scene_4/baking_prep_slide_open_drawer_scene_4": [
|
175 |
+
"Slide and open the drawer."
|
176 |
+
],
|
177 |
+
"home/jaydv/Documents/RoboSet/pick_banana_place_on_plate_data": [
|
178 |
+
"Pick banana from table and place on plate."
|
179 |
+
],
|
180 |
+
"set_8_pick_bottle_10": [
|
181 |
+
"Pick up the bottle."
|
182 |
+
],
|
183 |
+
"home/jaydv/Documents/RoboSet/pick_ketchup_place_in_toaster_data": [
|
184 |
+
"Pick ketchup from the table and place in toaster."
|
185 |
+
]
|
186 |
+
}
|
187 |
+
|
188 |
+
image_shape = (240, 424, 3)
|
189 |
+
Dmanus = ['']
|
190 |
+
def stash_image_into_observation(step):
|
191 |
+
step['observation'] = {'cam_high': [], 'cam_left_wrist': [], 'cam_right_wrist':[]}
|
192 |
+
step['observation']['cam_high'] = step['cam_high']
|
193 |
+
step['observation']['cam_left_wrist'] = step['cam_left_wrist']
|
194 |
+
step['observation']['cam_right_wrist'] = step['cam_right_wrist']
|
195 |
+
return step
|
196 |
+
|
197 |
+
def _parse_function(proto,instruction):
|
198 |
+
# Update the keys_to_features dictionary to match the new TFRecord format
|
199 |
+
keys_to_features = {
|
200 |
+
'action': tf.io.FixedLenFeature([], tf.string),
|
201 |
+
'action_gripper': tf.io.FixedLenFeature([], tf.string),
|
202 |
+
'qpos': tf.io.FixedLenFeature([], tf.string),
|
203 |
+
'qvel': tf.io.FixedLenFeature([], tf.string),
|
204 |
+
'qpos_gripper': tf.io.FixedLenFeature([], tf.string),
|
205 |
+
'qvel_gripper': tf.io.FixedLenFeature([], tf.string),
|
206 |
+
'rgb_left': tf.io.FixedLenFeature([], tf.string),
|
207 |
+
'rgb_right': tf.io.FixedLenFeature([], tf.string),
|
208 |
+
'rgb_top': tf.io.FixedLenFeature([], tf.string),
|
209 |
+
'terminate_episode': tf.io.FixedLenFeature([], tf.int64)
|
210 |
+
}
|
211 |
+
|
212 |
+
# Parse the incoming features according to the dictionary
|
213 |
+
parsed_features = tf.io.parse_single_example(proto, keys_to_features)
|
214 |
+
|
215 |
+
# Deserialize and reshape tensors as necessary
|
216 |
+
action = tf.io.parse_tensor(parsed_features['action'], out_type=tf.float16)
|
217 |
+
action_gripper = tf.io.parse_tensor(parsed_features['action_gripper'], out_type=tf.float16)
|
218 |
+
qpos = tf.io.parse_tensor(parsed_features['qpos'], out_type=tf.float16)
|
219 |
+
qvel = tf.io.parse_tensor(parsed_features['qvel'], out_type=tf.float16)
|
220 |
+
qpos_gripper = tf.io.parse_tensor(parsed_features['qpos_gripper'], out_type=tf.float16)
|
221 |
+
qvel_gripper = tf.io.parse_tensor(parsed_features['qvel_gripper'], out_type=tf.float16)
|
222 |
+
rgb_left = tf.io.parse_tensor(parsed_features['rgb_left'], out_type=tf.uint8)
|
223 |
+
rgb_right = tf.io.parse_tensor(parsed_features['rgb_right'], out_type=tf.uint8)
|
224 |
+
rgb_top = tf.io.parse_tensor(parsed_features['rgb_top'], out_type=tf.uint8)
|
225 |
+
terminate_episode = tf.cast(parsed_features['terminate_episode'], tf.int64)
|
226 |
+
|
227 |
+
# Reshape or modify other fields as needed to fit the model input
|
228 |
+
rgb_left = tf.reshape(rgb_left, image_shape)
|
229 |
+
rgb_right = tf.reshape(rgb_right, image_shape)
|
230 |
+
rgb_top = tf.reshape(rgb_top, image_shape)
|
231 |
+
|
232 |
+
return {
|
233 |
+
"action": action,
|
234 |
+
"action_gripper": action_gripper,
|
235 |
+
"qpos": qpos,
|
236 |
+
"qvel": qvel,
|
237 |
+
"qpos_gripper": qpos_gripper,
|
238 |
+
"qvel_gripper": qvel_gripper,
|
239 |
+
"observation": {
|
240 |
+
"rgb_left": rgb_left,
|
241 |
+
"rgb_right": rgb_right,
|
242 |
+
"rgb_top": rgb_top
|
243 |
+
},
|
244 |
+
"terminate_episode": terminate_episode,
|
245 |
+
"instruction": instruction
|
246 |
+
}
|
247 |
+
|
248 |
+
|
249 |
+
def dataset_generator_from_tfrecords(seed):
|
250 |
+
tfrecord_path = './data/datasets/roboset/tfrecords/'
|
251 |
+
failure = [f'set_{i}' for i in range(10, 18)]
|
252 |
+
filepaths = []
|
253 |
+
for root, dirs, files in os.walk(tfrecord_path):
|
254 |
+
# skip datasets with failure
|
255 |
+
fail = False
|
256 |
+
for f in failure:
|
257 |
+
if f in root:
|
258 |
+
fail = True
|
259 |
+
break
|
260 |
+
if fail:
|
261 |
+
continue
|
262 |
+
|
263 |
+
for filename in fnmatch.filter(files, '*.tfrecord'):
|
264 |
+
filepath = os.path.join(root, filename)
|
265 |
+
filepaths.append(filepath)
|
266 |
+
|
267 |
+
random.seed(seed)
|
268 |
+
random.shuffle(filepaths)
|
269 |
+
for filepath in filepaths:
|
270 |
+
for path in path2json:
|
271 |
+
if path in filepath:
|
272 |
+
instruction = path2json[path]
|
273 |
+
raw_dataset = tf.data.TFRecordDataset(filepath)
|
274 |
+
dataset = raw_dataset.map(lambda x: _parse_function(x,instruction))
|
275 |
+
yield {
|
276 |
+
'steps': dataset
|
277 |
+
}
|
278 |
+
|
279 |
+
def load_dataset(seed):
|
280 |
+
dataset = tf.data.Dataset.from_generator(
|
281 |
+
lambda: dataset_generator_from_tfrecords(seed),
|
282 |
+
output_signature={
|
283 |
+
'steps': tf.data.DatasetSpec(
|
284 |
+
element_spec={
|
285 |
+
'action': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
286 |
+
'action_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
287 |
+
'qpos': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
288 |
+
'qvel': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
289 |
+
'qpos_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
290 |
+
'qvel_gripper': tf.TensorSpec(shape=(None), dtype=tf.float16),
|
291 |
+
'observation': {
|
292 |
+
'rgb_left': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
|
293 |
+
'rgb_right': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
|
294 |
+
'rgb_top': tf.TensorSpec(shape=image_shape, dtype=tf.uint8),
|
295 |
+
},
|
296 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.int64),
|
297 |
+
'instruction': tf.TensorSpec(shape=(None), dtype=tf.string)
|
298 |
+
}
|
299 |
+
)
|
300 |
+
}
|
301 |
+
)
|
302 |
+
|
303 |
+
return dataset
|
304 |
+
|
305 |
+
|
306 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
307 |
+
"""
|
308 |
+
Convert terminate action to a boolean, where True means terminate.
|
309 |
+
"""
|
310 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float16)),tf.constant(False),tf.constant(True))
|
311 |
+
|
312 |
+
|
313 |
+
def process_step(step: dict) -> dict:
|
314 |
+
"""
|
315 |
+
Unify the action format and clean the task instruction.
|
316 |
+
|
317 |
+
DO NOT use python list, use tf.TensorArray instead.
|
318 |
+
"""
|
319 |
+
# Convert raw action to our action
|
320 |
+
step['action'] = {}
|
321 |
+
step['action']['terminate'] = step['terminate_episode']
|
322 |
+
# undetermined action
|
323 |
+
|
324 |
+
state = step['observation']
|
325 |
+
qpos = tf.cast(step['qpos'], tf.float32)
|
326 |
+
# qvel = tf.cast(step['qvel'], tf.float32)
|
327 |
+
gripper_pos = tf.expand_dims(tf.cast(step['qpos_gripper'], tf.float32), axis=0)
|
328 |
+
# delete due to all zeros
|
329 |
+
# gripper_vel = tf.expand_dims(tf.cast(step['qvel_gripper'], tf.float32), axis=0)
|
330 |
+
|
331 |
+
# state['arm_concat'] = tf.concat([qpos, qvel, gripper_pos, gripper_vel], axis=0)
|
332 |
+
state['arm_concat'] = tf.concat([qpos, gripper_pos], axis=0)
|
333 |
+
# state['format'] = tf.constant(
|
334 |
+
# "arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos,gripper_joint_0_vel"
|
335 |
+
# )
|
336 |
+
state['format'] = tf.constant(
|
337 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos"
|
338 |
+
)
|
339 |
+
# Clean the task instruction
|
340 |
+
# Define the replacements (old, new) as a dictionary
|
341 |
+
replacements = {
|
342 |
+
'_': ' ',
|
343 |
+
'1f': ' ',
|
344 |
+
'4f': ' ',
|
345 |
+
'-': ' ',
|
346 |
+
'50': ' ',
|
347 |
+
'55': ' ',
|
348 |
+
'56': ' ',
|
349 |
+
|
350 |
+
}
|
351 |
+
instr = step['instruction'][0]
|
352 |
+
instr = clean_task_instruction(instr, replacements)
|
353 |
+
step['observation']['natural_language_instruction'] = instr
|
354 |
+
|
355 |
+
return step
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
import tensorflow_datasets as tfds
|
360 |
+
from data.utils import dataset_to_path
|
361 |
+
|
362 |
+
dataset = load_dataset()
|
363 |
+
for step in dataset.take(100):
|
364 |
+
for data in step['steps']:
|
365 |
+
data = process_step(data)
|
366 |
+
print(data)
|
367 |
+
break
|
data/preprocess_scripts/roboturk.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler, euler_to_quaternion
|
4 |
+
|
5 |
+
|
6 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
7 |
+
"""
|
8 |
+
Convert terminate action to a boolean, where True means terminate.
|
9 |
+
"""
|
10 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
11 |
+
|
12 |
+
|
13 |
+
def process_step(step: dict) -> dict:
|
14 |
+
"""
|
15 |
+
Unify the action format and clean the task instruction.
|
16 |
+
|
17 |
+
DO NOT use python list, use tf.TensorArray instead.
|
18 |
+
"""
|
19 |
+
# Convert raw action to our action
|
20 |
+
action = step['action']
|
21 |
+
action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
|
22 |
+
|
23 |
+
eef_delta_pos = action['world_vector']
|
24 |
+
eef_ang = action['rotation_delta']
|
25 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
26 |
+
|
27 |
+
grip_open = tf.where(action['gripper_closedness_action']<0,tf.constant(1.0),tf.constant(0.0))
|
28 |
+
|
29 |
+
# No base found
|
30 |
+
|
31 |
+
# Concatenate the action
|
32 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
33 |
+
action['arm_concat'] = arm_action
|
34 |
+
|
35 |
+
# Write the action format
|
36 |
+
action['format'] = tf.constant(
|
37 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
38 |
+
|
39 |
+
# No state found
|
40 |
+
|
41 |
+
# Clean the task instruction
|
42 |
+
# Define the replacements (old, new) as a dictionary
|
43 |
+
replacements = {
|
44 |
+
'_': ' ',
|
45 |
+
'1f': ' ',
|
46 |
+
'4f': ' ',
|
47 |
+
'-': ' ',
|
48 |
+
'50': ' ',
|
49 |
+
'55': ' ',
|
50 |
+
'56': ' ',
|
51 |
+
|
52 |
+
}
|
53 |
+
instr = step['observation']['natural_language_instruction']
|
54 |
+
instr = clean_task_instruction(instr, replacements)
|
55 |
+
step['observation']['natural_language_instruction'] = instr
|
56 |
+
|
57 |
+
return step
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
import tensorflow_datasets as tfds
|
62 |
+
from data.utils import dataset_to_path
|
63 |
+
|
64 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
65 |
+
DATASET_NAME = 'roboturk'
|
66 |
+
# Load the dataset
|
67 |
+
dataset = tfds.builder_from_directory(
|
68 |
+
builder_dir=dataset_to_path(
|
69 |
+
DATASET_NAME, DATASET_DIR))
|
70 |
+
dataset = dataset.as_dataset(split='all').take(1)
|
71 |
+
|
72 |
+
# Inspect the dataset
|
73 |
+
ze=tf.constant(0.0)
|
74 |
+
for episode in dataset:
|
75 |
+
for step in episode['steps']:
|
76 |
+
print(step)
|
77 |
+
break
|
data/preprocess_scripts/roboturk_real_objectsearch.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_datasets as tfds
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
import tensorflow as tf
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
import imageio
|
10 |
+
import concurrent.futures
|
11 |
+
|
12 |
+
def get_frames(file_path):
|
13 |
+
if not os.path.exists(file_path) or not os.path.isfile(file_path) or not file_path.endswith('.mp4'):
|
14 |
+
return []
|
15 |
+
frames = []
|
16 |
+
with imageio.get_reader(file_path, 'ffmpeg') as reader:
|
17 |
+
for frame in reader:
|
18 |
+
frame = np.array(frame, dtype=np.uint8)
|
19 |
+
frames.append(frame)
|
20 |
+
return frames
|
21 |
+
|
22 |
+
def parallel_get_frames(paths):
|
23 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
24 |
+
future_to_path = {executor.submit(get_frames, path): path for path in paths}
|
25 |
+
return [future.result() for future in concurrent.futures.as_completed(future_to_path)]
|
26 |
+
|
27 |
+
def count_total_samples(filename):
|
28 |
+
total_samples = 0
|
29 |
+
with h5py.File(filename, 'r') as f:
|
30 |
+
data = f['data']
|
31 |
+
for user_key in data.keys():
|
32 |
+
user = data[user_key]
|
33 |
+
for demo_key in user.keys():
|
34 |
+
total_samples += 1
|
35 |
+
return total_samples
|
36 |
+
|
37 |
+
def dataset_generator(filename, total_samples):
|
38 |
+
with h5py.File(filename, 'r') as f:
|
39 |
+
data = f['data']
|
40 |
+
for user_key in data.keys():
|
41 |
+
user = data[user_key]
|
42 |
+
for demo_key in user.keys():
|
43 |
+
demo = user[demo_key]
|
44 |
+
robot_observation = demo['robot_observation']
|
45 |
+
user_control = demo['user_control']
|
46 |
+
|
47 |
+
eef_poses = robot_observation['eef_poses']
|
48 |
+
joint_states_arm = robot_observation['joint_states_arm']
|
49 |
+
joint_states_gripper = robot_observation['joint_states_gripper']
|
50 |
+
user_control_data = user_control['user_control']
|
51 |
+
|
52 |
+
attrs = dict(demo.attrs)
|
53 |
+
top_depth_video_file = attrs['top_depth_video_file']
|
54 |
+
top_rgb_video_file = attrs['top_rgb_video_file']
|
55 |
+
front_rgb_video_file = attrs['front_rgb_video_file']
|
56 |
+
|
57 |
+
video_root_path = './data/datasets/roboturk/'
|
58 |
+
top_depth_frames = get_frames(os.path.join(video_root_path, top_depth_video_file))
|
59 |
+
top_rgb_frames = get_frames(os.path.join(video_root_path, top_rgb_video_file))
|
60 |
+
front_rgb_frames = get_frames(os.path.join(video_root_path, front_rgb_video_file))
|
61 |
+
|
62 |
+
if len(top_rgb_frames) == 0 or len(front_rgb_frames) == 0:
|
63 |
+
continue
|
64 |
+
|
65 |
+
steps = []
|
66 |
+
for i in range(len(eef_poses)):
|
67 |
+
task_demo_id = f"SawyerTowerCreation_{demo_key}_{i}"
|
68 |
+
step = {
|
69 |
+
'task_demo_id': task_demo_id,
|
70 |
+
'eef_poses': eef_poses[i],
|
71 |
+
'joint_states_arm': joint_states_arm[i],
|
72 |
+
'joint_states_gripper': joint_states_gripper[i],
|
73 |
+
'user_control': user_control_data[i] if user_control_data.shape[0] > 0 else np.zeros(22),
|
74 |
+
'observation':{
|
75 |
+
'top_depth_frame': top_depth_frames[i] if i < len(top_depth_frames) else np.zeros((0,0, 3), dtype=np.uint8),
|
76 |
+
'top_rgb_frame': top_rgb_frames[i] if i < len(top_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
|
77 |
+
'front_rgb_frame': front_rgb_frames[i] if i < len(front_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
|
78 |
+
},
|
79 |
+
'terminate_episode': i == len(eef_poses) - 1
|
80 |
+
}
|
81 |
+
steps.append(step)
|
82 |
+
|
83 |
+
|
84 |
+
steps_dataset = tf.data.Dataset.from_generator(
|
85 |
+
lambda: iter(steps),
|
86 |
+
output_signature={
|
87 |
+
'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
|
88 |
+
'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
|
89 |
+
'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
|
90 |
+
'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
|
91 |
+
'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
|
92 |
+
'observation':{
|
93 |
+
'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
94 |
+
'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
95 |
+
'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
96 |
+
},
|
97 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.bool),
|
98 |
+
}
|
99 |
+
)
|
100 |
+
|
101 |
+
yield {'steps': steps_dataset}
|
102 |
+
|
103 |
+
def load_dataset():
|
104 |
+
filename = './data/datasets/roboturk/SawyerObjectSearch_aligned_dataset.hdf5'
|
105 |
+
total_samples = count_total_samples(filename)
|
106 |
+
dataset = tf.data.Dataset.from_generator(
|
107 |
+
lambda: dataset_generator(filename, total_samples),
|
108 |
+
output_signature={
|
109 |
+
'steps': tf.data.DatasetSpec(
|
110 |
+
element_spec={
|
111 |
+
'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
|
112 |
+
'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
|
113 |
+
'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
|
114 |
+
'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
|
115 |
+
'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
|
116 |
+
'observation':{
|
117 |
+
'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
118 |
+
'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
119 |
+
'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
120 |
+
},
|
121 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype = tf.bool),
|
122 |
+
}
|
123 |
+
)
|
124 |
+
}
|
125 |
+
)
|
126 |
+
return dataset
|
127 |
+
|
128 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
129 |
+
"""
|
130 |
+
Convert terminate action to a boolean, where True means terminate.
|
131 |
+
"""
|
132 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
133 |
+
|
134 |
+
|
135 |
+
def process_step(step: dict) -> dict:
|
136 |
+
"""
|
137 |
+
Unify the action format and clean the task instruction.
|
138 |
+
|
139 |
+
DO NOT use python list, use tf.TensorArray instead.
|
140 |
+
"""
|
141 |
+
# Convert raw action to our action
|
142 |
+
step['action'] = {}
|
143 |
+
action = step['action']
|
144 |
+
action['terminate'] = step['terminate_episode']
|
145 |
+
|
146 |
+
eef_delta_pos = step['eef_poses'][:3]
|
147 |
+
eef_ang = step['eef_poses'][3:]
|
148 |
+
|
149 |
+
# No base found
|
150 |
+
# Concatenate the action
|
151 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
|
152 |
+
action['arm_concat'] = arm_action
|
153 |
+
|
154 |
+
# Write the action format
|
155 |
+
action['format'] = tf.constant(
|
156 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w")
|
157 |
+
|
158 |
+
# No state found
|
159 |
+
state = step['observation']
|
160 |
+
# joint_states_arm: dataset of (num_timestamps, 27) shape where each of the 9 joints is represented by the JointState message
|
161 |
+
# (the nine joints are in order by their ROSBAG names: ['head_pan', 'right_j0', 'right_j1', 'right_j2', 'right_j3', 'right_j4', 'right_j5', 'right_j6', 'torso_t0']. For the most part, head_pan and torso should be zeros)
|
162 |
+
# [0] the position of the first joint (rad or m)
|
163 |
+
# [1] the velocity of the first joint (rad/s or m/s)
|
164 |
+
# [2] the effort that is applied in the first joint
|
165 |
+
# [3] the position of the second joint...
|
166 |
+
joint_states_arm = step['joint_states_arm']
|
167 |
+
joint_pos = joint_states_arm[3:24:3]
|
168 |
+
joint_vel = joint_states_arm[4:25:3]
|
169 |
+
# joint_states_gripper: dataset of (num_timestamps, 3) shape
|
170 |
+
# [0] the position of the gripper (rad or m)
|
171 |
+
# [1] the velocity of the gripper (rad/s or m/s)
|
172 |
+
# [2] the effort that is applied in the gripper
|
173 |
+
joint_states_gripper = step['joint_states_gripper']
|
174 |
+
gripper_pos = joint_states_gripper[:1]
|
175 |
+
# remove gripper_vel due to they are all zeros
|
176 |
+
# gripper_vel = joint_states_gripper[1:2]
|
177 |
+
# Concatenate the state
|
178 |
+
# state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos,gripper_vel], axis=0)
|
179 |
+
state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos], axis=0)
|
180 |
+
# Write the state format
|
181 |
+
state['format'] = tf.constant(
|
182 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos")
|
183 |
+
|
184 |
+
|
185 |
+
# Clean the task instruction
|
186 |
+
# Define the replacements (old, new) as a dictionary
|
187 |
+
replacements = {
|
188 |
+
'_': ' ',
|
189 |
+
'1f': ' ',
|
190 |
+
'4f': ' ',
|
191 |
+
'-': ' ',
|
192 |
+
'50': ' ',
|
193 |
+
'55': ' ',
|
194 |
+
'56': ' ',
|
195 |
+
|
196 |
+
}
|
197 |
+
# copied from openxembod
|
198 |
+
instr = b'create tower'
|
199 |
+
instr = clean_task_instruction(instr, replacements)
|
200 |
+
step['observation']['natural_language_instruction'] = instr
|
201 |
+
|
202 |
+
return step
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
import tensorflow_datasets as tfds
|
207 |
+
from data.utils import dataset_to_path
|
208 |
+
|
209 |
+
DATASET_DIR = '/cephfs-thu/gsm_data/openx_embod'
|
210 |
+
DATASET_NAME = 'roboturk_real_laundrylayout'
|
211 |
+
# Load the dataset
|
212 |
+
dataset = load_dataset()
|
213 |
+
|
214 |
+
# save_dir = os.path.join(DATASET_DIR, DATASET_NAME)
|
215 |
+
# if not os.path.exists(save_dir):
|
216 |
+
# os.makedirs(save_dir)
|
217 |
+
# tf.data.experimental.save(dataset, save_dir)
|
data/preprocess_scripts/roboturk_real_towercreation.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_datasets as tfds
|
3 |
+
from data.utils import clean_task_instruction, quaternion_to_euler
|
4 |
+
import tensorflow as tf
|
5 |
+
import h5py
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
import os
|
9 |
+
import imageio
|
10 |
+
import concurrent.futures
|
11 |
+
|
12 |
+
def get_frames(file_path):
|
13 |
+
if not os.path.exists(file_path) or not os.path.isfile(file_path) or not file_path.endswith('.mp4'):
|
14 |
+
return []
|
15 |
+
frames = []
|
16 |
+
with imageio.get_reader(file_path, 'ffmpeg') as reader:
|
17 |
+
for frame in reader:
|
18 |
+
frame = np.array(frame, dtype=np.uint8)
|
19 |
+
frames.append(frame)
|
20 |
+
return frames
|
21 |
+
|
22 |
+
def parallel_get_frames(paths):
|
23 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
24 |
+
future_to_path = {executor.submit(get_frames, path): path for path in paths}
|
25 |
+
return [future.result() for future in concurrent.futures.as_completed(future_to_path)]
|
26 |
+
|
27 |
+
def count_total_samples(filename):
|
28 |
+
total_samples = 0
|
29 |
+
with h5py.File(filename, 'r') as f:
|
30 |
+
data = f['data']
|
31 |
+
for user_key in data.keys():
|
32 |
+
user = data[user_key]
|
33 |
+
for demo_key in user.keys():
|
34 |
+
total_samples += 1
|
35 |
+
return total_samples
|
36 |
+
|
37 |
+
def dataset_generator(filename, total_samples):
|
38 |
+
with h5py.File(filename, 'r') as f:
|
39 |
+
data = f['data']
|
40 |
+
for user_key in data.keys():
|
41 |
+
user = data[user_key]
|
42 |
+
for demo_key in user.keys():
|
43 |
+
demo = user[demo_key]
|
44 |
+
robot_observation = demo['robot_observation']
|
45 |
+
user_control = demo['user_control']
|
46 |
+
|
47 |
+
eef_poses = robot_observation['eef_poses']
|
48 |
+
joint_states_arm = robot_observation['joint_states_arm']
|
49 |
+
joint_states_gripper = robot_observation['joint_states_gripper']
|
50 |
+
user_control_data = user_control['user_control']
|
51 |
+
|
52 |
+
attrs = dict(demo.attrs)
|
53 |
+
top_depth_video_file = attrs['top_depth_video_file']
|
54 |
+
top_rgb_video_file = attrs['top_rgb_video_file']
|
55 |
+
front_rgb_video_file = attrs['front_rgb_video_file']
|
56 |
+
|
57 |
+
video_root_path = './data/datasets/roboturk/'
|
58 |
+
top_depth_frames = get_frames(os.path.join(video_root_path, top_depth_video_file))
|
59 |
+
top_rgb_frames = get_frames(os.path.join(video_root_path, top_rgb_video_file))
|
60 |
+
front_rgb_frames = get_frames(os.path.join(video_root_path, front_rgb_video_file))
|
61 |
+
|
62 |
+
if len(top_rgb_frames) == 0 or len(front_rgb_frames) == 0:
|
63 |
+
continue
|
64 |
+
# video_root_path = '/cephfs-thu/gsm_data/robotruck'
|
65 |
+
# video_paths = [
|
66 |
+
# os.path.join(video_root_path, attrs['top_depth_video_file']),
|
67 |
+
# os.path.join(video_root_path, attrs['top_rgb_video_file']),
|
68 |
+
# os.path.join(video_root_path, attrs['front_rgb_video_file'])
|
69 |
+
# ]
|
70 |
+
# top_depth_frames, top_rgb_frames, front_rgb_frames = parallel_get_frames(video_paths)
|
71 |
+
|
72 |
+
steps = []
|
73 |
+
for i in range(len(eef_poses)):
|
74 |
+
task_demo_id = f"SawyerTowerCreation_{demo_key}_{i}"
|
75 |
+
step = {
|
76 |
+
'task_demo_id': task_demo_id,
|
77 |
+
'eef_poses': eef_poses[i],
|
78 |
+
'joint_states_arm': joint_states_arm[i],
|
79 |
+
'joint_states_gripper': joint_states_gripper[i],
|
80 |
+
'user_control': user_control_data[i] if user_control_data.shape[0] > 0 else np.zeros(22),
|
81 |
+
'observation':{
|
82 |
+
'top_depth_frame': top_depth_frames[i] if i < len(top_depth_frames) else np.zeros((0,0, 3), dtype=np.uint8),
|
83 |
+
'top_rgb_frame': top_rgb_frames[i] if i < len(top_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
|
84 |
+
'front_rgb_frame': front_rgb_frames[i] if i < len(front_rgb_frames) else np.zeros((0, 0, 3), dtype=np.uint8),
|
85 |
+
},
|
86 |
+
'terminate_episode': i == len(eef_poses) - 1
|
87 |
+
}
|
88 |
+
steps.append(step)
|
89 |
+
|
90 |
+
|
91 |
+
steps_dataset = tf.data.Dataset.from_generator(
|
92 |
+
lambda: iter(steps),
|
93 |
+
output_signature={
|
94 |
+
'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
|
95 |
+
'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
|
96 |
+
'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
|
97 |
+
'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
|
98 |
+
'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
|
99 |
+
'observation':{
|
100 |
+
'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
101 |
+
'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
102 |
+
'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
103 |
+
},
|
104 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype=tf.bool),
|
105 |
+
}
|
106 |
+
)
|
107 |
+
|
108 |
+
yield {'steps': steps_dataset}
|
109 |
+
|
110 |
+
def load_dataset():
|
111 |
+
filename = './data/datasets/roboturk/SawyerTowerCreation_aligned_dataset.hdf5'
|
112 |
+
total_samples = count_total_samples(filename)
|
113 |
+
dataset = tf.data.Dataset.from_generator(
|
114 |
+
lambda: dataset_generator(filename, total_samples),
|
115 |
+
output_signature={
|
116 |
+
'steps': tf.data.DatasetSpec(
|
117 |
+
element_spec={
|
118 |
+
'task_demo_id': tf.TensorSpec(shape=(), dtype=tf.string),
|
119 |
+
'eef_poses': tf.TensorSpec(shape=(7,), dtype=tf.float32),
|
120 |
+
'joint_states_arm': tf.TensorSpec(shape=(27,), dtype=tf.float32),
|
121 |
+
'joint_states_gripper': tf.TensorSpec(shape=(3,), dtype=tf.float32),
|
122 |
+
'user_control': tf.TensorSpec(shape=(22,), dtype=tf.float32),
|
123 |
+
'observation':{
|
124 |
+
'top_depth_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
125 |
+
'top_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
126 |
+
'front_rgb_frame': tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
|
127 |
+
},
|
128 |
+
'terminate_episode': tf.TensorSpec(shape=(), dtype = tf.bool),
|
129 |
+
}
|
130 |
+
)
|
131 |
+
}
|
132 |
+
)
|
133 |
+
return dataset
|
134 |
+
|
135 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
136 |
+
"""
|
137 |
+
Convert terminate action to a boolean, where True means terminate.
|
138 |
+
"""
|
139 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
140 |
+
|
141 |
+
|
142 |
+
def process_step(step: dict) -> dict:
|
143 |
+
"""
|
144 |
+
Unify the action format and clean the task instruction.
|
145 |
+
|
146 |
+
DO NOT use python list, use tf.TensorArray instead.
|
147 |
+
"""
|
148 |
+
# Convert raw action to our action
|
149 |
+
step['action'] = {}
|
150 |
+
action = step['action']
|
151 |
+
action['terminate'] = step['terminate_episode']
|
152 |
+
|
153 |
+
eef_delta_pos = step['eef_poses'][:3]
|
154 |
+
eef_ang = quaternion_to_euler(step['eef_poses'][3:])
|
155 |
+
|
156 |
+
# No base found
|
157 |
+
# Concatenate the action
|
158 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang], axis=0)
|
159 |
+
action['arm_concat'] = arm_action
|
160 |
+
|
161 |
+
# Write the action format
|
162 |
+
action['format'] = tf.constant(
|
163 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_roll,eef_delta_angle_pitch,eef_delta_angle_yaw")
|
164 |
+
|
165 |
+
# No state found
|
166 |
+
state = step['observation']
|
167 |
+
# joint_states_arm: dataset of (num_timestamps, 27) shape where each of the 9 joints is represented by the JointState message
|
168 |
+
# (the nine joints are in order by their ROSBAG names: ['head_pan', 'right_j0', 'right_j1', 'right_j2', 'right_j3', 'right_j4', 'right_j5', 'right_j6', 'torso_t0']. For the most part, head_pan and torso should be zeros)
|
169 |
+
# [0] the position of the first joint (rad or m)
|
170 |
+
# [1] the velocity of the first joint (rad/s or m/s)
|
171 |
+
# [2] the effort that is applied in the first joint
|
172 |
+
# [3] the position of the second joint...
|
173 |
+
joint_states_arm = step['joint_states_arm']
|
174 |
+
joint_pos = joint_states_arm[3:24:3]
|
175 |
+
joint_vel = joint_states_arm[4:25:3]
|
176 |
+
# joint_states_gripper: dataset of (num_timestamps, 3) shape
|
177 |
+
# [0] the position of the gripper (rad or m)
|
178 |
+
# [1] the velocity of the gripper (rad/s or m/s)
|
179 |
+
# [2] the effort that is applied in the gripper
|
180 |
+
joint_states_gripper = step['joint_states_gripper']
|
181 |
+
gripper_pos = joint_states_gripper[:1]
|
182 |
+
# remove gripper_vel due to they are all zeros
|
183 |
+
# gripper_vel = joint_states_gripper[1:2]
|
184 |
+
# Concatenate the state
|
185 |
+
# state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos,gripper_vel], axis=0)
|
186 |
+
state['arm_concat'] = tf.concat([joint_pos,joint_vel,gripper_pos], axis=0)
|
187 |
+
# Write the state format
|
188 |
+
state['format'] = tf.constant(
|
189 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,gripper_joint_0_pos")
|
190 |
+
|
191 |
+
# Clean the task instruction
|
192 |
+
# Define the replacements (old, new) as a dictionary
|
193 |
+
replacements = {
|
194 |
+
'_': ' ',
|
195 |
+
'1f': ' ',
|
196 |
+
'4f': ' ',
|
197 |
+
'-': ' ',
|
198 |
+
'50': ' ',
|
199 |
+
'55': ' ',
|
200 |
+
'56': ' ',
|
201 |
+
|
202 |
+
}
|
203 |
+
# copied from openxembod
|
204 |
+
instr = b'create tower'
|
205 |
+
instr = clean_task_instruction(instr, replacements)
|
206 |
+
step['observation']['natural_language_instruction'] = instr
|
207 |
+
|
208 |
+
return step
|
209 |
+
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
import tensorflow_datasets as tfds
|
213 |
+
from data.utils import dataset_to_path
|
214 |
+
|
215 |
+
DATASET_DIR = '/cephfs-thu/gsm_data/openx_embod'
|
216 |
+
DATASET_NAME = 'roboturk_real_laundrylayout'
|
217 |
+
# Load the dataset
|
218 |
+
dataset = load_dataset()
|
219 |
+
|
220 |
+
# save_dir = os.path.join(DATASET_DIR, DATASET_NAME)
|
221 |
+
# if not os.path.exists(save_dir):
|
222 |
+
# os.makedirs(save_dir)
|
223 |
+
# tf.data.experimental.save(dataset, save_dir)
|
data/preprocess_scripts/stanford_hydra_dataset_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix,\
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
8 |
+
"""
|
9 |
+
Convert terminate action to a boolean, where True means terminate.
|
10 |
+
"""
|
11 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
|
12 |
+
|
13 |
+
|
14 |
+
def process_step(step: dict) -> dict:
|
15 |
+
"""
|
16 |
+
Unify the action format and clean the task instruction.
|
17 |
+
|
18 |
+
DO NOT use python list, use tf.TensorArray instead.
|
19 |
+
"""
|
20 |
+
# Convert raw action to our action
|
21 |
+
action = step['action']
|
22 |
+
eef_delta_pos = action[:3]
|
23 |
+
eef_ang = action[3:6]
|
24 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
25 |
+
grip_open = tf.expand_dims(1 - action[6], axis=0)
|
26 |
+
|
27 |
+
# Concatenate the action
|
28 |
+
# action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
29 |
+
step['action'] = {}
|
30 |
+
action = step['action']
|
31 |
+
action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
32 |
+
action['terminate'] = step['is_terminal']
|
33 |
+
|
34 |
+
# Write the action format
|
35 |
+
action['format'] = tf.constant(
|
36 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
37 |
+
|
38 |
+
# Convert raw state to our state
|
39 |
+
state = step['observation']
|
40 |
+
state_vec = state['state']
|
41 |
+
# Robot state, consists of [3x EEF position,4x EEF orientation in quaternion,3x EEF orientation in euler angle,7x robot joint angles, 7x robot joint velocities,3x gripper state.
|
42 |
+
arm_joint_pos = state_vec[10:17]
|
43 |
+
arm_joint_vel = state_vec[17:24]
|
44 |
+
eef_pos = state_vec[:3]
|
45 |
+
eef_ang = state_vec[7:10]
|
46 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
47 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
48 |
+
# Rescale gripper width to [0, 1]
|
49 |
+
grip_joint_pos = tf.concat([
|
50 |
+
state_vec[24:25] * 12.324, state_vec[25:27]
|
51 |
+
], axis=0)
|
52 |
+
|
53 |
+
# Concatenate the state
|
54 |
+
state['arm_concat'] = tf.concat([arm_joint_pos, grip_joint_pos, arm_joint_vel, eef_pos, eef_ang], axis=0)
|
55 |
+
|
56 |
+
# Write the state format
|
57 |
+
state['format'] = tf.constant(
|
58 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_joint_0_pos,gripper_joint_1_pos,gripper_joint_2_pos,arm_joint_0_vel,arm_joint_1_vel,arm_joint_2_vel,arm_joint_3_vel,arm_joint_4_vel,arm_joint_5_vel,arm_joint_6_vel,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
59 |
+
|
60 |
+
# Clean the task instruction
|
61 |
+
# Define the replacements (old, new) as a dictionary
|
62 |
+
replacements = {
|
63 |
+
'_': ' ',
|
64 |
+
'1f': ' ',
|
65 |
+
'4f': ' ',
|
66 |
+
'-': ' ',
|
67 |
+
'50': ' ',
|
68 |
+
'55': ' ',
|
69 |
+
'56': ' ',
|
70 |
+
|
71 |
+
}
|
72 |
+
instr = step['language_instruction']
|
73 |
+
instr = clean_task_instruction(instr, replacements)
|
74 |
+
step['observation']['natural_language_instruction'] = instr
|
75 |
+
|
76 |
+
return step
|
77 |
+
|
78 |
+
|
79 |
+
if __name__ == "__main__":
|
80 |
+
import tensorflow_datasets as tfds
|
81 |
+
from data.utils import dataset_to_path
|
82 |
+
|
83 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
84 |
+
DATASET_NAME = 'fractal20220817_data'
|
85 |
+
# Load the dataset
|
86 |
+
dataset = tfds.builder_from_directory(
|
87 |
+
builder_dir=dataset_to_path(
|
88 |
+
DATASET_NAME, DATASET_DIR))
|
89 |
+
dataset = dataset.as_dataset(split='all')
|
90 |
+
|
91 |
+
# Inspect the dataset
|
92 |
+
for episode in dataset:
|
93 |
+
for step in episode['steps']:
|
94 |
+
print(step)
|
data/preprocess_scripts/tokyo_u_lsmo_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
8 |
+
"""
|
9 |
+
Convert terminate action to a boolean, where True means terminate.
|
10 |
+
"""
|
11 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
|
12 |
+
|
13 |
+
|
14 |
+
def process_step(step: dict) -> dict:
|
15 |
+
"""
|
16 |
+
Unify the action format and clean the task instruction.
|
17 |
+
|
18 |
+
DO NOT use python list, use tf.TensorArray instead.
|
19 |
+
"""
|
20 |
+
# Convert raw action to our action
|
21 |
+
action = step['action']
|
22 |
+
# Robot action, consists of [3x endeffector position, 3x euler angles,1x gripper action].
|
23 |
+
eef_delta_pos = action[:3]
|
24 |
+
eef_ang = action[3:6]
|
25 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
26 |
+
grip_open = tf.expand_dims(1 - action[6], axis=0)
|
27 |
+
|
28 |
+
# Concatenate the action
|
29 |
+
step['action'] = {}
|
30 |
+
action = step['action']
|
31 |
+
action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
32 |
+
action['terminate'] = step['is_terminal']
|
33 |
+
|
34 |
+
# Write the action format
|
35 |
+
action['format'] = tf.constant(
|
36 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
37 |
+
|
38 |
+
# Convert raw state to our state
|
39 |
+
state = step['observation']
|
40 |
+
state_vec = state['state']
|
41 |
+
# Robot state, consists of [3x endeffector position, 3x euler angles,6x robot joint angles, 1x gripper position].
|
42 |
+
eef_pos = state_vec[:3]
|
43 |
+
eef_ang = state_vec[3:6]
|
44 |
+
eef_ang = euler_to_rotation_matrix(eef_ang)
|
45 |
+
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
|
46 |
+
arm_joint_ang = state_vec[6:12]
|
47 |
+
grip_joint_pos = 1 - state_vec[12:13]
|
48 |
+
|
49 |
+
# Concatenate the state
|
50 |
+
state['arm_concat'] = tf.concat([arm_joint_ang, grip_joint_pos, eef_pos, eef_ang], axis=0)
|
51 |
+
|
52 |
+
# Write the state format
|
53 |
+
state['format'] = tf.constant(
|
54 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,gripper_joint_0_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
55 |
+
|
56 |
+
# Clean the task instruction
|
57 |
+
# Define the replacements (old, new) as a dictionary
|
58 |
+
replacements = {
|
59 |
+
'_': ' ',
|
60 |
+
'1f': ' ',
|
61 |
+
'4f': ' ',
|
62 |
+
'-': ' ',
|
63 |
+
'50': ' ',
|
64 |
+
'55': ' ',
|
65 |
+
'56': ' ',
|
66 |
+
|
67 |
+
}
|
68 |
+
instr = step['language_instruction']
|
69 |
+
instr = clean_task_instruction(instr, replacements)
|
70 |
+
step['observation']['natural_language_instruction'] = instr
|
71 |
+
|
72 |
+
return step
|
73 |
+
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
import tensorflow_datasets as tfds
|
77 |
+
from data.utils import dataset_to_path
|
78 |
+
|
79 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
80 |
+
DATASET_NAME = 'fractal20220817_data'
|
81 |
+
# Load the dataset
|
82 |
+
dataset = tfds.builder_from_directory(
|
83 |
+
builder_dir=dataset_to_path(
|
84 |
+
DATASET_NAME, DATASET_DIR))
|
85 |
+
dataset = dataset.as_dataset(split='all')
|
86 |
+
|
87 |
+
# Inspect the dataset
|
88 |
+
for episode in dataset:
|
89 |
+
for step in episode['steps']:
|
90 |
+
print(step)
|
data/preprocess_scripts/utokyo_pr2_opening_fridge_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, euler_to_rotation_matrix, \
|
4 |
+
rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
8 |
+
"""
|
9 |
+
Convert terminate action to a boolean, where True means terminate.
|
10 |
+
"""
|
11 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.float32)))
|
12 |
+
|
13 |
+
|
14 |
+
def process_step(step: dict) -> dict:
|
15 |
+
"""
|
16 |
+
Unify the action format and clean the task instruction.
|
17 |
+
|
18 |
+
DO NOT use python list, use tf.TensorArray instead.
|
19 |
+
"""
|
20 |
+
|
21 |
+
# Convert raw action to our action
|
22 |
+
action = step['action']
|
23 |
+
# Robot action, consists of [3x end effector pos, 3x robot rpy angles, 1x gripper open/close command, 1x terminal action].
|
24 |
+
eef_delta_pos = action[:3]/1000 # change from mm to m
|
25 |
+
eef_ang = action[3:6]
|
26 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
27 |
+
grip_open = tf.expand_dims(1 - action[6], axis=0)
|
28 |
+
|
29 |
+
# Concatenate the action
|
30 |
+
step['action'] = {}
|
31 |
+
action = step['action']
|
32 |
+
action['arm_concat'] = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
33 |
+
|
34 |
+
action['terminate'] = step['is_terminal']
|
35 |
+
# Write the action format
|
36 |
+
action['format'] = tf.constant(
|
37 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
38 |
+
|
39 |
+
|
40 |
+
# Convert raw state to our state
|
41 |
+
state = step['observation']['state']
|
42 |
+
# Robot state, consists of [3x end effector pos, 3x robot rpy angles, 1x gripper position].
|
43 |
+
gripper_pos = state[:3]/1000 # change from mm to m
|
44 |
+
gripper_ang = state[3:6]
|
45 |
+
gripper_ang = euler_to_rotation_matrix(gripper_ang)
|
46 |
+
gripper_ang = rotation_matrix_to_ortho6d(gripper_ang)
|
47 |
+
gripper_open = state[6:7]/1000 * 11.54 # rescale to [0, 1]
|
48 |
+
|
49 |
+
|
50 |
+
# Concatenate the state
|
51 |
+
state = step['observation']
|
52 |
+
state['arm_concat'] = tf.concat([gripper_pos, gripper_ang, gripper_open], axis=0)
|
53 |
+
|
54 |
+
# Write the state format
|
55 |
+
state['format'] = tf.constant(
|
56 |
+
"eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,gripper_joint_0_pos")
|
57 |
+
|
58 |
+
# Clean the task instruction
|
59 |
+
# Define the replacements (old, new) as a dictionary
|
60 |
+
replacements = {
|
61 |
+
'_': ' ',
|
62 |
+
'1f': ' ',
|
63 |
+
'4f': ' ',
|
64 |
+
'-': ' ',
|
65 |
+
'50': ' ',
|
66 |
+
'55': ' ',
|
67 |
+
'56': ' ',
|
68 |
+
|
69 |
+
}
|
70 |
+
instr = step['language_instruction']
|
71 |
+
instr = clean_task_instruction(instr, replacements)
|
72 |
+
step['observation']['natural_language_instruction'] = instr
|
73 |
+
|
74 |
+
return step
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
import tensorflow_datasets as tfds
|
79 |
+
from data.utils import dataset_to_path
|
80 |
+
|
81 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
82 |
+
DATASET_NAME = 'fractal20220817_data'
|
83 |
+
# Load the dataset
|
84 |
+
dataset = tfds.builder_from_directory(
|
85 |
+
builder_dir=dataset_to_path(
|
86 |
+
DATASET_NAME, DATASET_DIR))
|
87 |
+
dataset = dataset.as_dataset(split='all')
|
88 |
+
|
89 |
+
# Inspect the dataset
|
90 |
+
for episode in dataset:
|
91 |
+
for step in episode['steps']:
|
92 |
+
print(step)
|
data/preprocess_scripts/utokyo_xarm_bimanual_converted_externally_to_rlds.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import tensorflow as tf
|
3 |
+
|
4 |
+
from data.utils import clean_task_instruction, euler_to_rotation_matrix, rotation_matrix_to_ortho6d
|
5 |
+
|
6 |
+
|
7 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
8 |
+
"""
|
9 |
+
Convert terminate action to a boolean, where True means terminate.
|
10 |
+
"""
|
11 |
+
return tf.reduce_all(tf.equal(terminate_act, tf.constant([1, 0, 0], dtype=tf.int32)))
|
12 |
+
|
13 |
+
|
14 |
+
def process_step(step: dict) -> dict:
|
15 |
+
"""
|
16 |
+
Unify the action format and clean the task instruction.
|
17 |
+
|
18 |
+
DO NOT use python list, use tf.TensorArray instead.
|
19 |
+
"""
|
20 |
+
# Convert raw action to our action
|
21 |
+
action = step['action']
|
22 |
+
|
23 |
+
# TODO
|
24 |
+
# action : Tensor (14,) [3x EEF position (L), 3x EEF orientation yaw/pitch/roll (L), 1x gripper open/close position (L), 3x EEF position (R), 3x EEF orientation yaw/pitch/roll (R), 1x gripper open/close position (R)].
|
25 |
+
|
26 |
+
eef_pos_left = action[0:3]
|
27 |
+
eef_angle_left = tf.gather(action[3:6], [2, 1, 0])
|
28 |
+
eef_angle_left = euler_to_rotation_matrix(eef_angle_left)
|
29 |
+
eef_angle_left = rotation_matrix_to_ortho6d(eef_angle_left)
|
30 |
+
gripper_open_left = 1 - action[6:7]
|
31 |
+
eef_pos_right = action[7:10]
|
32 |
+
eef_angle_right = tf.gather(action[10:13], [2, 1, 0])
|
33 |
+
eef_angle_right = euler_to_rotation_matrix(eef_angle_right)
|
34 |
+
eef_angle_right = rotation_matrix_to_ortho6d(eef_angle_right)
|
35 |
+
gripper_open_right = 1 - action[13:14]
|
36 |
+
|
37 |
+
# Concatenate the action
|
38 |
+
step['action'] = {}
|
39 |
+
action = step['action']
|
40 |
+
|
41 |
+
# Concatenate the action
|
42 |
+
arm_action = tf.concat([eef_pos_left,eef_angle_left,gripper_open_left,eef_pos_right,eef_angle_right,gripper_open_right], axis=0)
|
43 |
+
action['arm_concat'] = arm_action
|
44 |
+
action['terminate'] = step['is_terminal']
|
45 |
+
|
46 |
+
# print("action len:", len(action['arm_concat']) + len(action['base_concat']))
|
47 |
+
|
48 |
+
action['format'] = tf.constant(
|
49 |
+
"left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open")
|
50 |
+
|
51 |
+
# action good for kuka same as example
|
52 |
+
|
53 |
+
# Convert raw state to our state
|
54 |
+
action = step['observation']['action_l']
|
55 |
+
# [3x EEF position, 3x EEF orientation yaw/pitch/roll, 1x gripper open/close position].
|
56 |
+
eef_pos_left = action[0:3]
|
57 |
+
eef_angle_left = tf.gather(action[3:6], [2, 1, 0])
|
58 |
+
eef_angle_left = euler_to_rotation_matrix(eef_angle_left)
|
59 |
+
eef_angle_left = rotation_matrix_to_ortho6d(eef_angle_left)
|
60 |
+
gripper_open_left = 1 - action[6:7]
|
61 |
+
|
62 |
+
action = step['observation']['action_r']
|
63 |
+
eef_pos_right = action[0:3]
|
64 |
+
eef_angle_right = tf.gather(action[3:6], [2, 1, 0])
|
65 |
+
eef_angle_right = euler_to_rotation_matrix(eef_angle_right)
|
66 |
+
eef_angle_right = rotation_matrix_to_ortho6d(eef_angle_right)
|
67 |
+
gripper_open_right = 1 - action[6:7]
|
68 |
+
|
69 |
+
# Write the state format TODO how to link 12 joint pos to 7 joint pos ??
|
70 |
+
state = step['observation']
|
71 |
+
# Concatenate the state
|
72 |
+
state['arm_concat'] = tf.concat([eef_pos_left,eef_angle_left,gripper_open_left,eef_pos_right,eef_angle_right,gripper_open_right], axis=0)
|
73 |
+
state['format'] = tf.constant(
|
74 |
+
"left_eef_pos_x,left_eef_pos_y,left_eef_pos_z,left_eef_angle_0,left_eef_angle_1,left_eef_angle_2,left_eef_angle_3,left_eef_angle_4,left_eef_angle_5,left_gripper_open,right_eef_pos_x,right_eef_pos_y,right_eef_pos_z,right_eef_angle_0,right_eef_angle_1,right_eef_angle_2,right_eef_angle_3,right_eef_angle_4,right_eef_angle_5,right_gripper_open")
|
75 |
+
|
76 |
+
# Clean the task instruction
|
77 |
+
# Define the replacements (old, new) as a dictionary
|
78 |
+
replacements = {
|
79 |
+
'_': ' ',
|
80 |
+
'1f': ' ',
|
81 |
+
'4f': ' ',
|
82 |
+
'-': ' ',
|
83 |
+
'50': ' ',
|
84 |
+
'55': ' ',
|
85 |
+
'56': ' ',
|
86 |
+
|
87 |
+
}
|
88 |
+
instr = step['language_instruction']
|
89 |
+
instr = clean_task_instruction(instr, replacements)
|
90 |
+
step['observation']['natural_language_instruction'] = instr
|
91 |
+
|
92 |
+
return step
|
93 |
+
|
94 |
+
|
95 |
+
if __name__ == "__main__":
|
96 |
+
import tensorflow_datasets as tfds
|
97 |
+
from data.utils import dataset_to_path
|
98 |
+
|
99 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
100 |
+
DATASET_NAME = 'utokyo_xarm_bimanual_converted_externally_to_rlds'
|
101 |
+
# Load the dataset
|
102 |
+
dataset = tfds.builder_from_directory(
|
103 |
+
builder_dir=dataset_to_path(
|
104 |
+
DATASET_NAME, DATASET_DIR))
|
105 |
+
dataset = dataset.as_dataset(split='all')
|
106 |
+
|
107 |
+
# with open('example.txt', 'w') as file:
|
108 |
+
# Inspect the dataset
|
109 |
+
|
110 |
+
episode_num = len(dataset)
|
111 |
+
print(f"episode_num: {episode_num}")
|
112 |
+
for episode in dataset.take(1):
|
113 |
+
# print("episode")
|
114 |
+
# print(list(episode.keys()))
|
115 |
+
for step in episode['steps']:
|
116 |
+
process_step(step)
|
117 |
+
break
|
data/preprocess_scripts/viola.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
|
3 |
+
from data.utils import clean_task_instruction, euler_to_quaternion, rotation_matrix_to_ortho6d
|
4 |
+
|
5 |
+
|
6 |
+
def terminate_act_to_bool(terminate_act: tf.Tensor) -> tf.Tensor:
|
7 |
+
"""
|
8 |
+
Convert terminate action to a boolean, where True means terminate.
|
9 |
+
"""
|
10 |
+
return tf.where(tf.equal(terminate_act, tf.constant(0.0, dtype=tf.float32)),tf.constant(False),tf.constant(True))
|
11 |
+
|
12 |
+
|
13 |
+
def process_step(step: dict) -> dict:
|
14 |
+
"""
|
15 |
+
Unify the action format and clean the task instruction.
|
16 |
+
|
17 |
+
DO NOT use python list, use tf.TensorArray instead.
|
18 |
+
"""
|
19 |
+
# Convert raw action to our action
|
20 |
+
action = step['action']
|
21 |
+
action['terminate'] = terminate_act_to_bool(action['terminate_episode'])
|
22 |
+
|
23 |
+
eef_delta_pos = action['world_vector']
|
24 |
+
eef_ang = action['rotation_delta']
|
25 |
+
eef_ang = euler_to_quaternion(eef_ang)
|
26 |
+
grip_open = tf.reshape(tf.where(action['gripper_closedness_action']<0,tf.constant(1.0),tf.constant(0.0)),(1,))
|
27 |
+
|
28 |
+
# No base found
|
29 |
+
|
30 |
+
# Concatenate the action
|
31 |
+
arm_action = tf.concat([eef_delta_pos, eef_ang, grip_open], axis=0)
|
32 |
+
action['arm_concat'] = arm_action
|
33 |
+
|
34 |
+
# Write the action format
|
35 |
+
action['format'] = tf.constant(
|
36 |
+
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z,eef_delta_angle_x,eef_delta_angle_y,eef_delta_angle_z,eef_delta_angle_w,gripper_open")
|
37 |
+
|
38 |
+
# Convert raw state to our state
|
39 |
+
state = step['observation']
|
40 |
+
joint_pos=state['joint_states']
|
41 |
+
grip_open=state['gripper_states'] * 12.905 # rescale to [0, 1]
|
42 |
+
state_ee=state['ee_states']
|
43 |
+
transform_matrix = tf.transpose(tf.reshape(state_ee, (4, 4)))
|
44 |
+
eef_pos = transform_matrix[:3, 3]
|
45 |
+
rotation_matrix = transform_matrix[:3, :3]
|
46 |
+
eef_ang = rotation_matrix_to_ortho6d(rotation_matrix)
|
47 |
+
|
48 |
+
# Concatenate the state
|
49 |
+
state['arm_concat'] = tf.concat([joint_pos,grip_open,eef_pos,eef_ang],axis=0)
|
50 |
+
|
51 |
+
# Write the state format
|
52 |
+
state['format'] = tf.constant(
|
53 |
+
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,arm_joint_6_pos,gripper_open,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5")
|
54 |
+
|
55 |
+
# Clean the task instruction
|
56 |
+
# Define the replacements (old, new) as a dictionary
|
57 |
+
replacements = {
|
58 |
+
'_': ' ',
|
59 |
+
'1f': ' ',
|
60 |
+
'4f': ' ',
|
61 |
+
'-': ' ',
|
62 |
+
'50': ' ',
|
63 |
+
'55': ' ',
|
64 |
+
'56': ' ',
|
65 |
+
|
66 |
+
}
|
67 |
+
instr = step['observation']['natural_language_instruction']
|
68 |
+
instr = clean_task_instruction(instr, replacements)
|
69 |
+
step['observation']['natural_language_instruction'] = instr
|
70 |
+
|
71 |
+
return step
|
72 |
+
|
73 |
+
|
74 |
+
if __name__ == "__main__":
|
75 |
+
import tensorflow_datasets as tfds
|
76 |
+
from data.utils import dataset_to_path
|
77 |
+
|
78 |
+
DATASET_DIR = 'data/datasets/openx_embod'
|
79 |
+
DATASET_NAME = 'viola'
|
80 |
+
# Load the dataset
|
81 |
+
dataset = tfds.builder_from_directory(
|
82 |
+
builder_dir=dataset_to_path(
|
83 |
+
DATASET_NAME, DATASET_DIR))
|
84 |
+
dataset = dataset.as_dataset(split='all')
|
85 |
+
|
86 |
+
# Inspect the dataset
|
87 |
+
for episode in dataset:
|
88 |
+
for step in episode['steps']:
|
89 |
+
print(step)
|
data/producer.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import time
|
5 |
+
import argparse
|
6 |
+
import sys
|
7 |
+
import signal
|
8 |
+
import random
|
9 |
+
from multiprocessing import Process
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import tensorflow as tf
|
13 |
+
import yaml
|
14 |
+
|
15 |
+
from data.vla_dataset import VLADataset
|
16 |
+
from data.filelock import FileLock
|
17 |
+
|
18 |
+
|
19 |
+
# Producer does not need GPU
|
20 |
+
tf.config.set_visible_devices([], 'GPU')
|
21 |
+
|
22 |
+
# Read the config
|
23 |
+
with open('configs/base.yaml', 'r') as file:
|
24 |
+
config = yaml.safe_load(file)
|
25 |
+
# Load some constants from the config
|
26 |
+
BUF_PATH = config['dataset']['buf_path']
|
27 |
+
BUF_NUM_CHUNKS = config['dataset']['buf_num_chunks']
|
28 |
+
if BUF_NUM_CHUNKS < 1:
|
29 |
+
raise ValueError("Config `buf_num_chunks` must be at least 1.")
|
30 |
+
BUF_CHUNK_SIZE = config['dataset']['buf_chunk_size']
|
31 |
+
if BUF_CHUNK_SIZE < 1:
|
32 |
+
raise ValueError("Config `buf_chunk_size` must be at least 1.")
|
33 |
+
|
34 |
+
|
35 |
+
def get_dirty_item(chunk_dir):
|
36 |
+
"""
|
37 |
+
Get indexes of dirty items in a chunk.
|
38 |
+
"""
|
39 |
+
dirty_bit = read_dirty_bit(chunk_dir)
|
40 |
+
return np.where(dirty_bit)[0].tolist()
|
41 |
+
|
42 |
+
|
43 |
+
def get_clean_item(chunk_dir):
|
44 |
+
"""
|
45 |
+
Get indexes of clean items in a chunk.
|
46 |
+
"""
|
47 |
+
dirty_bit = read_dirty_bit(chunk_dir)
|
48 |
+
return np.where(1 - dirty_bit)[0].tolist()
|
49 |
+
|
50 |
+
|
51 |
+
def save_dirty_bit(chunk_dir, dirty_bit):
|
52 |
+
"""
|
53 |
+
Save the dirty bit to the chunk directory.
|
54 |
+
"""
|
55 |
+
time_stmp = time.time()
|
56 |
+
while time.time() - time_stmp < 10.0:
|
57 |
+
try:
|
58 |
+
file_path = os.path.join(chunk_dir, "dirty_bit")
|
59 |
+
lock = FileLock(file_path)
|
60 |
+
lock.acquire_write_lock()
|
61 |
+
with open(file_path, 'wb') as file:
|
62 |
+
file.write(dirty_bit.tobytes())
|
63 |
+
lock.release_lock()
|
64 |
+
return
|
65 |
+
except KeyboardInterrupt:
|
66 |
+
lock.release_lock()
|
67 |
+
raise KeyboardInterrupt
|
68 |
+
except BaseException:
|
69 |
+
lock.release_lock()
|
70 |
+
continue
|
71 |
+
# raise RuntimeError("Failed to save dirty bit.")
|
72 |
+
print("Failed to save dirty bit.")
|
73 |
+
|
74 |
+
|
75 |
+
def read_dirty_bit(chunk_dir):
|
76 |
+
"""
|
77 |
+
Read the dirty bit from the chunk directory.
|
78 |
+
"""
|
79 |
+
# If error occurs, retry
|
80 |
+
time_stmp = time.time()
|
81 |
+
while time.time() - time_stmp < 10.0:
|
82 |
+
try:
|
83 |
+
file_path = os.path.join(chunk_dir, "dirty_bit")
|
84 |
+
lock = FileLock(file_path)
|
85 |
+
lock.acquire_read_lock()
|
86 |
+
with open(file_path, 'rb') as file:
|
87 |
+
dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
|
88 |
+
lock.release_lock()
|
89 |
+
assert len(dirty_bit) == BUF_CHUNK_SIZE
|
90 |
+
return dirty_bit
|
91 |
+
except KeyboardInterrupt:
|
92 |
+
lock.release_lock()
|
93 |
+
raise KeyboardInterrupt
|
94 |
+
except BaseException:
|
95 |
+
lock.release_lock()
|
96 |
+
continue
|
97 |
+
# If failed to read the dirty bit, return all ones for robustness
|
98 |
+
return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
99 |
+
|
100 |
+
|
101 |
+
def save_sample(step_dict, chunk_dir, chunk_item_idx):
|
102 |
+
"""
|
103 |
+
Save a sample to the chunk directory.
|
104 |
+
"""
|
105 |
+
# Save the json content
|
106 |
+
time_stmp = time.time()
|
107 |
+
while time.time() - time_stmp < 10.0:
|
108 |
+
try:
|
109 |
+
locks = []
|
110 |
+
json_content = step_dict['json_content']
|
111 |
+
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json")
|
112 |
+
lock = FileLock(file_path)
|
113 |
+
locks.append(lock)
|
114 |
+
lock.acquire_write_lock()
|
115 |
+
with open(file_path, 'w') as file:
|
116 |
+
json.dump(json_content, file, indent=4)
|
117 |
+
lock.release_lock()
|
118 |
+
# Save all other tensors in a npz
|
119 |
+
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz")
|
120 |
+
lock = FileLock(file_path)
|
121 |
+
locks.append(lock)
|
122 |
+
lock.acquire_write_lock()
|
123 |
+
with open(file_path, 'wb') as file:
|
124 |
+
np.savez(
|
125 |
+
file,
|
126 |
+
step_id=step_dict['step_id'].numpy(),
|
127 |
+
state_chunk=step_dict['state_chunk'].numpy(),
|
128 |
+
state_chunk_time_mask=step_dict['state_chunk_time_mask'].numpy(),
|
129 |
+
action_chunk=step_dict['action_chunk'].numpy(),
|
130 |
+
action_chunk_time_mask=step_dict['action_chunk_time_mask'].numpy(),
|
131 |
+
state_vec_mask=step_dict['state_vec_mask'].numpy(),
|
132 |
+
past_frames_0=step_dict['past_frames_0'].numpy(),
|
133 |
+
past_frames_0_time_mask=step_dict['past_frames_0_time_mask'].numpy(),
|
134 |
+
past_frames_1=step_dict['past_frames_1'].numpy(),
|
135 |
+
past_frames_1_time_mask=step_dict['past_frames_1_time_mask'].numpy(),
|
136 |
+
past_frames_2=step_dict['past_frames_2'].numpy(),
|
137 |
+
past_frames_2_time_mask=step_dict['past_frames_2_time_mask'].numpy(),
|
138 |
+
past_frames_3=step_dict['past_frames_3'].numpy(),
|
139 |
+
past_frames_3_time_mask=step_dict['past_frames_3_time_mask'].numpy(),
|
140 |
+
state_std=step_dict['state_std'].numpy(),
|
141 |
+
state_mean=step_dict['state_mean'].numpy(),
|
142 |
+
state_norm=step_dict['state_norm'].numpy(),
|
143 |
+
)
|
144 |
+
lock.release_lock()
|
145 |
+
return
|
146 |
+
except KeyboardInterrupt:
|
147 |
+
for lock in locks:
|
148 |
+
lock.release_lock()
|
149 |
+
raise KeyboardInterrupt
|
150 |
+
except BaseException:
|
151 |
+
for lock in locks:
|
152 |
+
lock.release_lock()
|
153 |
+
continue
|
154 |
+
# raise RuntimeError("Failed to save sample.")
|
155 |
+
print("Failed to save sample.")
|
156 |
+
|
157 |
+
|
158 |
+
def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type):
|
159 |
+
"""
|
160 |
+
Run the producer.
|
161 |
+
The producer will first fill up the buffer with samples.
|
162 |
+
Then it will keep replacing dirty samples
|
163 |
+
(i.e., samples that have been read by the consumer)
|
164 |
+
with new samples.
|
165 |
+
"""
|
166 |
+
vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type)
|
167 |
+
chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers
|
168 |
+
chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers
|
169 |
+
if fill_up:
|
170 |
+
print(f"Worker {worker_id}: Start filling up the buffer...")
|
171 |
+
elif clean_dirty:
|
172 |
+
# Only refresh the dirty bits
|
173 |
+
print(f"Worker {worker_id}: Start refreshing the dirty bits...")
|
174 |
+
for chunk_idx in range(chunk_start_idx, chunk_end_idx):
|
175 |
+
chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}")
|
176 |
+
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
177 |
+
save_dirty_bit(chunk_dir, dirty_bit)
|
178 |
+
print(f"Worker {worker_id}: Refreshed the dirty bits.")
|
179 |
+
|
180 |
+
fill_chunk_idx = chunk_start_idx
|
181 |
+
fill_chunk_item_idx = 0
|
182 |
+
dirty_chunk_idx = chunk_start_idx
|
183 |
+
dirty_chunk_item_idxs = []
|
184 |
+
time_stmp = time.time()
|
185 |
+
for episode_steps in vla_dataset:
|
186 |
+
for step in episode_steps:
|
187 |
+
if fill_up and fill_chunk_idx < chunk_end_idx:
|
188 |
+
# Fill up the buffer
|
189 |
+
chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}")
|
190 |
+
if fill_chunk_item_idx == 0:
|
191 |
+
# Create a new chunk
|
192 |
+
os.makedirs(chunk_dir, exist_ok=True)
|
193 |
+
# Write the dirty bit of size BUF_CHUNK_SIZE
|
194 |
+
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
195 |
+
save_dirty_bit(chunk_dir, dirty_bit)
|
196 |
+
|
197 |
+
# Save the sample
|
198 |
+
save_sample(step, chunk_dir, fill_chunk_item_idx)
|
199 |
+
|
200 |
+
# print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}")
|
201 |
+
local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx
|
202 |
+
local_num_chunks = chunk_end_idx - chunk_start_idx
|
203 |
+
if (local_fill_chunk_idx % 10 == 0 or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0:
|
204 |
+
print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}")
|
205 |
+
fill_chunk_item_idx += 1
|
206 |
+
if fill_chunk_item_idx == BUF_CHUNK_SIZE:
|
207 |
+
fill_chunk_idx += 1
|
208 |
+
fill_chunk_item_idx = 0
|
209 |
+
if fill_chunk_idx == BUF_NUM_CHUNKS:
|
210 |
+
print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...")
|
211 |
+
|
212 |
+
else:
|
213 |
+
# Search for the dirty chunk to replace
|
214 |
+
while len(dirty_chunk_item_idxs) == 0:
|
215 |
+
dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
|
216 |
+
dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir)
|
217 |
+
# Print the dirty ratio
|
218 |
+
if time.time() - time_stmp > 2.0:
|
219 |
+
dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE
|
220 |
+
print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}")
|
221 |
+
time_stmp = time.time()
|
222 |
+
|
223 |
+
if len(dirty_chunk_item_idxs) > 0:
|
224 |
+
# Lock the chunk
|
225 |
+
dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8)
|
226 |
+
save_dirty_bit(dirty_chunk_dir, dirty_bit)
|
227 |
+
|
228 |
+
# Iterate over the chunks
|
229 |
+
dirty_chunk_idx += 1
|
230 |
+
if dirty_chunk_idx == chunk_end_idx:
|
231 |
+
dirty_chunk_idx = chunk_start_idx
|
232 |
+
|
233 |
+
# Replace the dirty item
|
234 |
+
dirty_item_idx = dirty_chunk_item_idxs.pop()
|
235 |
+
chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}")
|
236 |
+
# Save the sample
|
237 |
+
save_sample(step, chunk_dir, dirty_item_idx)
|
238 |
+
|
239 |
+
# If we have replaced all dirty items in the chunk
|
240 |
+
if len(dirty_chunk_item_idxs) == 0:
|
241 |
+
# Unlock the chunk
|
242 |
+
dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8)
|
243 |
+
save_dirty_bit(dirty_chunk_dir, dirty_bit)
|
244 |
+
print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.")
|
245 |
+
|
246 |
+
|
247 |
+
if __name__ == '__main__':
|
248 |
+
# Args: n_workers, fill_up
|
249 |
+
parser = argparse.ArgumentParser()
|
250 |
+
parser.add_argument('--n_workers', type=int, default=2, help="Number of parallel workers. It should be less than or equal to the number of chunks.")
|
251 |
+
parser.add_argument('--fill_up', action='store_true', help="Whether to fill up the buffer before replacing dirty samples.")
|
252 |
+
parser.add_argument('--clean_dirty', action='store_true', help="Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.")
|
253 |
+
parser.add_argument('--seed', type=int, default=None, help="Random seed. If not set, the seed will be randomly generated.")
|
254 |
+
parser.add_argument('--dataset_type', type=str,
|
255 |
+
default="pretrain",
|
256 |
+
help="Whether to load the pretrain dataset or finetune dataset.")
|
257 |
+
|
258 |
+
# Run the producer
|
259 |
+
args = parser.parse_args()
|
260 |
+
if args.seed is not None:
|
261 |
+
print(f"Base seed: {args.seed}")
|
262 |
+
random.seed(args.seed)
|
263 |
+
|
264 |
+
processes = []
|
265 |
+
process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)]
|
266 |
+
print(f"Process seeds: {process_seeds}")
|
267 |
+
def signal_handler(sig, frame):
|
268 |
+
print("Ctrl+C received. Terminating child processes...")
|
269 |
+
for p in processes:
|
270 |
+
p.terminate()
|
271 |
+
sys.exit(0)
|
272 |
+
signal.signal(signal.SIGINT, signal_handler)
|
273 |
+
for worker_id in range(args.n_workers):
|
274 |
+
p = Process(target=run_producer, args=(
|
275 |
+
process_seeds[worker_id], args.n_workers, worker_id, args.fill_up, args.clean_dirty, args.dataset_type))
|
276 |
+
p.start()
|
277 |
+
processes.append(p)
|
278 |
+
|
279 |
+
for p in processes:
|
280 |
+
p.join()
|
data/utils.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import tensorflow as tf
|
2 |
+
import tensorflow_graphics.geometry.transformation.euler as tf_euler
|
3 |
+
import tensorflow_graphics.geometry.transformation.quaternion as tf_quat
|
4 |
+
import tensorflow_graphics.geometry.transformation.rotation_matrix_3d as tf_rotmat
|
5 |
+
|
6 |
+
|
7 |
+
def dataset_to_path(dataset_name: str, dir_name: str) -> str:
|
8 |
+
"""
|
9 |
+
Return the path to the dataset.
|
10 |
+
"""
|
11 |
+
if dataset_name == 'robo_net' or \
|
12 |
+
dataset_name == 'cmu_playing_with_food' or \
|
13 |
+
dataset_name == 'droid':
|
14 |
+
version = '1.0.0'
|
15 |
+
elif dataset_name == 'language_table' or \
|
16 |
+
dataset_name == 'fmb' or \
|
17 |
+
dataset_name == 'dobbe':
|
18 |
+
version = '0.0.1'
|
19 |
+
elif dataset_name == 'nyu_door_opening_surprising_effectiveness':
|
20 |
+
version = ''
|
21 |
+
elif dataset_name == 'cmu_play_fusion':
|
22 |
+
version=''
|
23 |
+
elif dataset_name=='berkeley_gnm_recon':
|
24 |
+
version=''
|
25 |
+
elif dataset_name=='vla_benchmark_ee':
|
26 |
+
version = '1.0.0'
|
27 |
+
else:
|
28 |
+
version = '1.0.0'
|
29 |
+
return f'{dir_name}/{dataset_name}/{version}'
|
30 |
+
|
31 |
+
|
32 |
+
def clean_task_instruction(
|
33 |
+
task_instruction: tf.Tensor, replacements: dict) -> tf.Tensor:
|
34 |
+
"""
|
35 |
+
Clean up the natural language task instruction.
|
36 |
+
"""
|
37 |
+
# Create a function that applies all replacements
|
38 |
+
def apply_replacements(tensor):
|
39 |
+
for old, new in replacements.items():
|
40 |
+
tensor = tf.strings.regex_replace(tensor, old, new)
|
41 |
+
return tensor
|
42 |
+
# Apply the replacements and strip leading and trailing spaces
|
43 |
+
cleaned_task_instruction = apply_replacements(task_instruction)
|
44 |
+
cleaned_task_instruction = tf.strings.strip(cleaned_task_instruction)
|
45 |
+
return cleaned_task_instruction
|
46 |
+
|
47 |
+
|
48 |
+
def quaternion_to_euler(quaternion: tf.Tensor) -> tf.Tensor:
|
49 |
+
"""
|
50 |
+
Convert a quaternion (x, y, z, w) to Euler angles (roll, pitch, yaw).
|
51 |
+
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
52 |
+
"""
|
53 |
+
# Normalize the quaternion
|
54 |
+
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
55 |
+
return tf_euler.from_quaternion(quaternion)
|
56 |
+
|
57 |
+
|
58 |
+
def euler_to_quaternion(euler: tf.Tensor) -> tf.Tensor:
|
59 |
+
"""
|
60 |
+
Convert Euler angles (roll, pitch, yaw) to a quaternion (x, y, z, w).
|
61 |
+
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
62 |
+
"""
|
63 |
+
quaternion = tf_quat.from_euler(euler)
|
64 |
+
return tf.nn.l2_normalize(quaternion, axis=-1)
|
65 |
+
|
66 |
+
|
67 |
+
def rotation_matrix_to_euler(matrix: tf.Tensor) -> tf.Tensor:
|
68 |
+
"""
|
69 |
+
Convert a 3x3 rotation matrix to Euler angles (roll, pitch, yaw).
|
70 |
+
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
71 |
+
"""
|
72 |
+
return tf_euler.from_rotation_matrix(matrix)
|
73 |
+
|
74 |
+
|
75 |
+
def rotation_matrix_to_quaternion(matrix: tf.Tensor) -> tf.Tensor:
|
76 |
+
"""
|
77 |
+
Convert a 3x3 rotation matrix to a quaternion (x, y, z, w).
|
78 |
+
"""
|
79 |
+
quaternion = tf_quat.from_rotation_matrix(matrix)
|
80 |
+
return tf.nn.l2_normalize(quaternion, axis=-1)
|
81 |
+
|
82 |
+
|
83 |
+
def euler_to_rotation_matrix(euler: tf.Tensor) -> tf.Tensor:
|
84 |
+
"""
|
85 |
+
Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
|
86 |
+
The (roll, pitch, yaw) corresponds to `Rotation.as_euler("xyz")` convention.
|
87 |
+
"""
|
88 |
+
return tf_rotmat.from_euler(euler)
|
89 |
+
|
90 |
+
|
91 |
+
def quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
|
92 |
+
"""
|
93 |
+
Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
|
94 |
+
"""
|
95 |
+
# Normalize the quaternion
|
96 |
+
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
97 |
+
return tf_rotmat.from_quaternion(quaternion)
|
98 |
+
|
99 |
+
|
100 |
+
def quaternion_to_rotation_matrix_wo_static_check(quaternion: tf.Tensor) -> tf.Tensor:
|
101 |
+
"""
|
102 |
+
Convert a quaternion (x, y, z, w) to a 3x3 rotation matrix.
|
103 |
+
This function is used to make tensorflow happy.
|
104 |
+
"""
|
105 |
+
# Normalize the quaternion
|
106 |
+
quaternion = tf.nn.l2_normalize(quaternion, axis=-1)
|
107 |
+
|
108 |
+
x = quaternion[..., 0]
|
109 |
+
y = quaternion[..., 1]
|
110 |
+
z = quaternion[..., 2]
|
111 |
+
w = quaternion[..., 3]
|
112 |
+
|
113 |
+
tx = 2.0 * x
|
114 |
+
ty = 2.0 * y
|
115 |
+
tz = 2.0 * z
|
116 |
+
twx = tx * w
|
117 |
+
twy = ty * w
|
118 |
+
twz = tz * w
|
119 |
+
txx = tx * x
|
120 |
+
txy = ty * x
|
121 |
+
txz = tz * x
|
122 |
+
tyy = ty * y
|
123 |
+
tyz = tz * y
|
124 |
+
tzz = tz * z
|
125 |
+
matrix = tf.stack((1.0 - (tyy + tzz), txy - twz, txz + twy,
|
126 |
+
txy + twz, 1.0 - (txx + tzz), tyz - twx,
|
127 |
+
txz - twy, tyz + twx, 1.0 - (txx + tyy)),
|
128 |
+
axis=-1) # pyformat: disable
|
129 |
+
output_shape = tf.concat((tf.shape(input=quaternion)[:-1], (3, 3)), axis=-1)
|
130 |
+
return tf.reshape(matrix, shape=output_shape)
|
131 |
+
|
132 |
+
|
133 |
+
"""
|
134 |
+
Below is a continuous 6D rotation representation adapted from
|
135 |
+
On the Continuity of Rotation Representations in Neural Networks
|
136 |
+
https://arxiv.org/pdf/1812.07035.pdf
|
137 |
+
https://github.com/papagina/RotationContinuity/blob/master/sanity_test/code/tools.py
|
138 |
+
"""
|
139 |
+
def rotation_matrix_to_ortho6d(matrix: tf.Tensor) -> tf.Tensor:
|
140 |
+
"""
|
141 |
+
The orhto6d represents the first two column vectors a1 and a2 of the
|
142 |
+
rotation matrix: [ | , |, | ]
|
143 |
+
[ a1, a2, a3]
|
144 |
+
[ | , |, | ]
|
145 |
+
Input: (A1, ..., An, 3, 3)
|
146 |
+
Output: (A1, ..., An, 6)
|
147 |
+
"""
|
148 |
+
ortho6d = matrix[..., :, :2]
|
149 |
+
# Transpose the last two dimension
|
150 |
+
perm = list(range(len(ortho6d.shape)))
|
151 |
+
perm[-2], perm[-1] = perm[-1], perm[-2]
|
152 |
+
ortho6d = tf.transpose(ortho6d, perm)
|
153 |
+
# Flatten the last two dimension
|
154 |
+
ortho6d = tf.reshape(ortho6d, ortho6d.shape[:-2] + [6])
|
155 |
+
return ortho6d
|
156 |
+
|
157 |
+
|
158 |
+
def rotation_matrix_to_ortho6d_1d(matrix: tf.Tensor) -> tf.Tensor:
|
159 |
+
"""
|
160 |
+
The orhto6d represents the first two column vectors a1 and a2 of the
|
161 |
+
rotation matrix: [ | , |, | ]
|
162 |
+
[ a1, a2, a3]
|
163 |
+
[ | , |, | ]
|
164 |
+
Input: (3, 3)
|
165 |
+
Output: (6,)
|
166 |
+
This function is used to make tensorflow happy.
|
167 |
+
"""
|
168 |
+
ortho6d = matrix[:, :2]
|
169 |
+
# Transpose the last two dimension
|
170 |
+
ortho6d = tf.transpose(ortho6d)
|
171 |
+
# Flatten the last two dimension
|
172 |
+
ortho6d = tf.reshape(ortho6d, [6])
|
173 |
+
return ortho6d
|
174 |
+
|
175 |
+
|
176 |
+
def normalize_vector(v):
|
177 |
+
"""
|
178 |
+
v: (..., N)
|
179 |
+
"""
|
180 |
+
v_mag = tf.sqrt(tf.reduce_sum(tf.square(v), axis=-1, keepdims=True))
|
181 |
+
v_mag = tf.maximum(v_mag, 1e-8)
|
182 |
+
v_normalized = v / v_mag
|
183 |
+
|
184 |
+
return v_normalized
|
185 |
+
|
186 |
+
|
187 |
+
def cross_product(u, v):
|
188 |
+
"""
|
189 |
+
u: (..., 3)
|
190 |
+
v: (..., 3)
|
191 |
+
u x v: (..., 3)
|
192 |
+
"""
|
193 |
+
i = u[..., 1] * v[..., 2] - u[..., 2] * v[..., 1]
|
194 |
+
j = u[..., 2] * v[..., 0] - u[..., 0] * v[..., 2]
|
195 |
+
k = u[..., 0] * v[..., 1] - u[..., 1] * v[..., 0]
|
196 |
+
out = tf.stack([i, j, k], axis=-1)
|
197 |
+
return out
|
198 |
+
|
199 |
+
|
200 |
+
def ortho6d_to_rotation_matrix(ortho6d: tf.Tensor) -> tf.Tensor:
|
201 |
+
"""
|
202 |
+
The orhto6d represents the first two column vectors a1 and a2 of the
|
203 |
+
rotation matrix: [ | , |, | ]
|
204 |
+
[ a1, a2, a3]
|
205 |
+
[ | , |, | ]
|
206 |
+
Input: (A1, ..., An, 6)
|
207 |
+
Output: (A1, ..., An, 3, 3)
|
208 |
+
"""
|
209 |
+
x_raw = ortho6d[..., 0:3]
|
210 |
+
y_raw = ortho6d[..., 3:6]
|
211 |
+
|
212 |
+
x = normalize_vector(x_raw)
|
213 |
+
z = cross_product(x, y_raw)
|
214 |
+
z = normalize_vector(z)
|
215 |
+
y = cross_product(z, x)
|
216 |
+
|
217 |
+
# Stack x, y, z to form the matrix
|
218 |
+
matrix = tf.stack([x, y, z], axis=-1)
|
219 |
+
return matrix
|
220 |
+
|
221 |
+
|
222 |
+
def capitalize_and_period(instr: str) -> str:
|
223 |
+
"""
|
224 |
+
Capitalize the first letter of a string and add a period to the end if it's not there.
|
225 |
+
"""
|
226 |
+
if len(instr) > 0:
|
227 |
+
# if the first letter is not capital, make it so
|
228 |
+
if not instr[0].isupper():
|
229 |
+
# if the first letter is not capital, make it so
|
230 |
+
instr = instr[0].upper() + instr[1:]
|
231 |
+
# add period to the end if it's not there
|
232 |
+
if instr[-1] != '.':
|
233 |
+
# add period to the end if it's not there
|
234 |
+
instr = instr + '.'
|
235 |
+
return instr
|
data/vla_dataset.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import tensorflow as tf
|
6 |
+
import tensorflow_datasets as tfds
|
7 |
+
import yaml
|
8 |
+
|
9 |
+
from data.episode_transform import process_episode, flatten_episode, \
|
10 |
+
flatten_episode_agilex, bgr_to_rgb
|
11 |
+
from data.utils import dataset_to_path
|
12 |
+
from data.preprocess_scripts import *
|
13 |
+
|
14 |
+
# Producer does not need GPU
|
15 |
+
tf.config.set_visible_devices([], 'GPU')
|
16 |
+
|
17 |
+
OPENX_EMBOD_DIR = 'data/datasets/openx_embod'
|
18 |
+
|
19 |
+
DATASET_NAMES_NOOPENX = [
|
20 |
+
"aloha_mobile",
|
21 |
+
"aloha_static",
|
22 |
+
"roboset",
|
23 |
+
"agilex",
|
24 |
+
"rh20t",
|
25 |
+
'calvin',
|
26 |
+
"bridgev2"
|
27 |
+
]
|
28 |
+
|
29 |
+
# Read the config
|
30 |
+
with open('configs/base.yaml', 'r') as file:
|
31 |
+
config = yaml.safe_load(file)
|
32 |
+
# Load some constants from the config
|
33 |
+
EPSD_LEN_THRESH_LOW = config['dataset']['epsd_len_thresh_low']
|
34 |
+
EPSD_LEN_THRESH_HIGH = config['dataset']['epsd_len_thresh_high']
|
35 |
+
# Read the image keys of each dataset
|
36 |
+
with open('configs/dataset_img_keys.json', 'r') as file:
|
37 |
+
IMAGE_KEYS = json.load(file)
|
38 |
+
|
39 |
+
|
40 |
+
class VLADataset:
|
41 |
+
"""
|
42 |
+
This class is used to sample episodes from the embododiment dataset.
|
43 |
+
"""
|
44 |
+
def __init__(self, seed, dataset_type, repeat=True):
|
45 |
+
'''
|
46 |
+
seed: the random seed
|
47 |
+
dataset_type: 'pretrain' or 'finetune', which dataset to load
|
48 |
+
repeat: whether to repeat to infinite length
|
49 |
+
'''
|
50 |
+
dataset_names_cfg = 'configs/pretrain_datasets.json' \
|
51 |
+
if dataset_type == "pretrain" else 'configs/finetune_datasets.json'
|
52 |
+
with open(dataset_names_cfg, 'r') as file:
|
53 |
+
DATASET_NAMES = json.load(file)
|
54 |
+
self.dataset_names = DATASET_NAMES
|
55 |
+
sample_weights_cfg = 'configs/pretrain_sample_weights.json' \
|
56 |
+
if dataset_type == "pretrain" else 'configs/finetune_sample_weights.json'
|
57 |
+
# Load the sample weights
|
58 |
+
with open(sample_weights_cfg, 'r') as file:
|
59 |
+
SAMPLE_WEIGHTS = json.load(file)
|
60 |
+
self.openx_dir = OPENX_EMBOD_DIR
|
61 |
+
self.epsd_len_thresh_low = EPSD_LEN_THRESH_LOW
|
62 |
+
self.epsd_len_thresh_high = EPSD_LEN_THRESH_HIGH
|
63 |
+
self.repeat = repeat
|
64 |
+
|
65 |
+
# Set the random seed
|
66 |
+
tf.random.set_seed(seed)
|
67 |
+
np.random.seed(seed)
|
68 |
+
|
69 |
+
# Weights of the each dataset in the collection to sample from
|
70 |
+
sample_weights = []
|
71 |
+
|
72 |
+
self.name2dataset = {}
|
73 |
+
for dataset_name in self.dataset_names:
|
74 |
+
if dataset_name in DATASET_NAMES_NOOPENX:
|
75 |
+
dataset = globals()[dataset_name].load_dataset(seed)
|
76 |
+
else:
|
77 |
+
dataset_path = dataset_to_path(dataset_name, self.openx_dir)
|
78 |
+
dataset = tfds.builder_from_directory(builder_dir=dataset_path)
|
79 |
+
dataset = dataset.as_dataset(split='all', shuffle_files=True)
|
80 |
+
|
81 |
+
# You can add filter for other datasets
|
82 |
+
if dataset_name == 'kuka':
|
83 |
+
dataset = dataset.filter(
|
84 |
+
lambda x: x['success'])
|
85 |
+
elif dataset_name == 'bc_z':
|
86 |
+
dataset = dataset.filter(
|
87 |
+
lambda x: tf.math.greater(
|
88 |
+
next(iter(x['steps']))['observation']['episode_success'], 0.5))
|
89 |
+
elif dataset_name == 'ucsd_pick_and_place_dataset_converted_externally_to_rlds':
|
90 |
+
dataset = dataset.filter(
|
91 |
+
lambda x: x['episode_metadata']['success'])
|
92 |
+
elif dataset_name == 'utokyo_xarm_bimanual_converted_externally_to_rlds':
|
93 |
+
# Only preserve the meaningful episodes
|
94 |
+
dataset = dataset.filter(
|
95 |
+
lambda x: tf.math.equal(
|
96 |
+
next(iter(x['steps']))['language_instruction'],
|
97 |
+
tf.constant('Unfold a wrinkled towel.')))
|
98 |
+
|
99 |
+
# Note: use cache() will cause the unexpected crash
|
100 |
+
# dataset = dataset.map().cache().shuffle().repeat()
|
101 |
+
print(dataset_name)
|
102 |
+
dataset = dataset\
|
103 |
+
.map(
|
104 |
+
lambda x: process_episode(x, dataset_name,
|
105 |
+
IMAGE_KEYS[dataset_name]['image_keys'],
|
106 |
+
IMAGE_KEYS[dataset_name]['image_mask'])
|
107 |
+
)
|
108 |
+
|
109 |
+
# Change BGR to RGB if needed
|
110 |
+
if dataset_name == 'fmb':
|
111 |
+
dataset = dataset.map(bgr_to_rgb)
|
112 |
+
|
113 |
+
if self.repeat:
|
114 |
+
dataset = dataset.repeat()
|
115 |
+
self.name2dataset[dataset_name] = iter(dataset)
|
116 |
+
print(SAMPLE_WEIGHTS)
|
117 |
+
sample_weights.append(SAMPLE_WEIGHTS[dataset_name])
|
118 |
+
# Normalize the sample weights
|
119 |
+
sample_weights = np.array(sample_weights)
|
120 |
+
self.sample_weights = sample_weights / np.sum(sample_weights)
|
121 |
+
|
122 |
+
def __iter__(self):
|
123 |
+
'''
|
124 |
+
Sample batches of episodes for an epoch.
|
125 |
+
'''
|
126 |
+
while True:
|
127 |
+
dataset_name = np.random.choice(self.dataset_names, p=self.sample_weights)
|
128 |
+
episode = next(self.name2dataset[dataset_name])
|
129 |
+
if dataset_name == "agilex":
|
130 |
+
episode_steps = flatten_episode_agilex(episode)
|
131 |
+
else:
|
132 |
+
episode_steps = flatten_episode(episode)
|
133 |
+
# Filter too short
|
134 |
+
if len(episode_steps) < self.epsd_len_thresh_low:
|
135 |
+
continue
|
136 |
+
# Randomly sample too long
|
137 |
+
if len(episode_steps) > self.epsd_len_thresh_high:
|
138 |
+
episode_steps = random.sample(episode_steps, self.epsd_len_thresh_high)
|
139 |
+
|
140 |
+
yield episode_steps
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == "__main__":
|
144 |
+
dataset = VLADataset(0, 'finetune')
|
145 |
+
for episode in dataset:
|
146 |
+
print(episode[0])
|
147 |
+
break
|
encode_lang.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
from models.multimodal_encoder.t5_encoder import T5Embedder
|
7 |
+
|
8 |
+
|
9 |
+
GPU = 0
|
10 |
+
MODEL_PATH = "google/t5-v1_1-xxl"
|
11 |
+
CONFIG_PATH = "configs/base.yaml"
|
12 |
+
SAVE_DIR = "lang_embed/"
|
13 |
+
|
14 |
+
# Modify this to your task name and instruction
|
15 |
+
TASK_NAME = "anubis_carrot_to_bag"
|
16 |
+
# INSTRUCTION = "take the towel off the kirby doll"
|
17 |
+
# INSTRUCTION = "insert the brush to the dustpan"
|
18 |
+
INSTRUCTION = "pick up the carrot and put into the bag"
|
19 |
+
|
20 |
+
# Note: if your GPU VRAM is less than 24GB,
|
21 |
+
# it is recommended to enable offloading by specifying an offload directory.
|
22 |
+
# OFFLOAD_DIR = '/home/jellyho/OFFLOAD' # Specify your offload directory here, ensuring the directory exists.
|
23 |
+
|
24 |
+
def main():
|
25 |
+
with open(CONFIG_PATH, "r") as fp:
|
26 |
+
config = yaml.safe_load(fp)
|
27 |
+
|
28 |
+
device = torch.device(f"cuda:{GPU}")
|
29 |
+
text_embedder = T5Embedder(
|
30 |
+
from_pretrained=MODEL_PATH,
|
31 |
+
model_max_length=config["dataset"]["tokenizer_max_length"],
|
32 |
+
device=device,
|
33 |
+
# use_offload_folder=OFFLOAD_DIR
|
34 |
+
)
|
35 |
+
tokenizer, text_encoder = text_embedder.tokenizer, text_embedder.model
|
36 |
+
|
37 |
+
tokens = tokenizer(
|
38 |
+
INSTRUCTION, return_tensors="pt",
|
39 |
+
padding="longest",
|
40 |
+
truncation=True
|
41 |
+
)["input_ids"].to(device)
|
42 |
+
|
43 |
+
tokens = tokens.view(1, -1)
|
44 |
+
with torch.no_grad():
|
45 |
+
pred = text_encoder(tokens).last_hidden_state.detach().cpu()
|
46 |
+
|
47 |
+
save_path = os.path.join(SAVE_DIR, f"{TASK_NAME}.pt")
|
48 |
+
# We save the embeddings in a dictionary format
|
49 |
+
torch.save({
|
50 |
+
"name": TASK_NAME,
|
51 |
+
"instruction": INSTRUCTION,
|
52 |
+
"embeddings": pred
|
53 |
+
}, save_path
|
54 |
+
)
|
55 |
+
|
56 |
+
print(f'\"{INSTRUCTION}\" from \"{TASK_NAME}\" is encoded by \"{MODEL_PATH}\" into shape {pred.shape} and saved to \"{save_path}\"')
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
main()
|
finetune.sh
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
|
2 |
+
export NCCL_IB_DISABLE=0
|
3 |
+
export NCCL_SOCKET_IFNAME=eth1
|
4 |
+
export NCCL_DEBUG=INFO
|
5 |
+
export NCCL_NVLS_ENABLE=0
|
6 |
+
export MASTER_PORT=$2
|
7 |
+
export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
|
8 |
+
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
|
9 |
+
export OUTPUT_DIR="./checkpoints/$1"
|
10 |
+
export CFLAGS="-I/usr/include"
|
11 |
+
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
|
12 |
+
export CUTLASS_PATH="/home/jellyho/cutlass"
|
13 |
+
|
14 |
+
export WANDB_PROJECT="robotics_diffusion_transformer"
|
15 |
+
|
16 |
+
if [ ! -d "$OUTPUT_DIR" ]; then
|
17 |
+
mkdir "$OUTPUT_DIR"
|
18 |
+
echo "Folder '$OUTPUT_DIR' created"
|
19 |
+
else
|
20 |
+
echo "Folder '$OUTPUT_DIR' already exists"
|
21 |
+
fi
|
22 |
+
|
23 |
+
# For run in a single node/machine
|
24 |
+
# accelerate launch main.py \
|
25 |
+
# --deepspeed="./configs/zero2.json" \
|
26 |
+
# ...
|
27 |
+
# --hostfile=hostfile.txt
|
28 |
+
|
29 |
+
accelerate launch --main_process_port $2 --num_processes 2 --num_machines 1 --mixed_precision bf16 main.py \
|
30 |
+
--deepspeed="./configs/zero2.json" \
|
31 |
+
--pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
|
32 |
+
--pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
|
33 |
+
--pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
|
34 |
+
--output_dir=$OUTPUT_DIR \
|
35 |
+
--train_batch_size=8 \
|
36 |
+
--sample_batch_size=8 \
|
37 |
+
--max_train_steps=50000 \
|
38 |
+
--checkpointing_period=5000 \
|
39 |
+
--sample_period=1000 \
|
40 |
+
--checkpoints_total_limit=10 \
|
41 |
+
--lr_scheduler="constant" \
|
42 |
+
--learning_rate=1e-4 \
|
43 |
+
--mixed_precision="bf16" \
|
44 |
+
--dataloader_num_workers=16 \
|
45 |
+
--image_aug \
|
46 |
+
--dataset_type="finetune" \
|
47 |
+
--gradient_accumulation_steps 1 \
|
48 |
+
--report_to=wandb \
|
49 |
+
--load_from_hdf5 \
|
50 |
+
--dataset_name $1 \
|
51 |
+
--precomp_lang_embed
|
52 |
+
# --resume_from_checkpoint="checkpoint-50000"
|
53 |
+
|
54 |
+
# Use this to resume training from some previous checkpoint
|
55 |
+
# --resume_from_checkpoint="checkpoint-36000" \
|
56 |
+
# Use this to load from saved lanuage instruction embeddings,
|
57 |
+
# instead of calculating it during training
|
finetune_maniskill.sh
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export NCCL_IB_HCA=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
|
2 |
+
export NCCL_IB_DISABLE=0
|
3 |
+
export NCCL_SOCKET_IFNAME=bond0
|
4 |
+
export NCCL_DEBUG=INFO
|
5 |
+
export NCCL_NVLS_ENABLE=0
|
6 |
+
|
7 |
+
export TEXT_ENCODER_NAME="google/t5-v1_1-xxl"
|
8 |
+
export VISION_ENCODER_NAME="google/siglip-so400m-patch14-384"
|
9 |
+
export OUTPUT_DIR="./checkpoints/rdt-finetune-1b-sim"
|
10 |
+
export CFLAGS="-I/usr/include"
|
11 |
+
export LDFLAGS="-L/usr/lib/x86_64-linux-gnu"
|
12 |
+
export CUTLASS_PATH="/data/lingxuan/cutlass"
|
13 |
+
|
14 |
+
export WANDB_PROJECT="robotic_diffusion_transformer"
|
15 |
+
|
16 |
+
if [ ! -d "$OUTPUT_DIR" ]; then
|
17 |
+
mkdir "$OUTPUT_DIR"
|
18 |
+
echo "Folder '$OUTPUT_DIR' created"
|
19 |
+
else
|
20 |
+
echo "Folder '$OUTPUT_DIR' already exists"
|
21 |
+
fi
|
22 |
+
# For run in a single node/machine
|
23 |
+
# accelerate launch main.py \
|
24 |
+
# --deepspeed="./configs/zero2.json" \
|
25 |
+
# ...
|
26 |
+
|
27 |
+
accelerate launch main.py \
|
28 |
+
--deepspeed="./configs/zero2.json" \
|
29 |
+
--pretrained_model_name_or_path="robotics-diffusion-transformer/rdt-1b" \
|
30 |
+
--pretrained_text_encoder_name_or_path=$TEXT_ENCODER_NAME \
|
31 |
+
--pretrained_vision_encoder_name_or_path=$VISION_ENCODER_NAME \
|
32 |
+
--output_dir=$OUTPUT_DIR \
|
33 |
+
--train_batch_size=24 \
|
34 |
+
--sample_batch_size=32 \
|
35 |
+
--max_train_steps=400000 \
|
36 |
+
--checkpointing_period=10000 \
|
37 |
+
--sample_period=500 \
|
38 |
+
--checkpoints_total_limit=40 \
|
39 |
+
--lr_scheduler="constant" \
|
40 |
+
--learning_rate=1e-4 \
|
41 |
+
--mixed_precision="bf16" \
|
42 |
+
--dataloader_num_workers=8 \
|
43 |
+
--image_aug \
|
44 |
+
--dataset_type="finetune" \
|
45 |
+
--state_noise_snr=40 \
|
46 |
+
--load_from_hdf5 \
|
47 |
+
--report_to=wandb
|
48 |
+
|
inference.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python -m scripts.agilex_inference \
|
2 |
+
--use_actions_interpolation \
|
3 |
+
--pretrained_model_name_or_path="checkpoints/your_finetuned_ckpt.pt" \ # your finetuned checkpoint: e.g., checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>, checkpoints/rdt-finetune-1b/checkpoint-<STEP NUMBER>/pytorch_model/mp_rank_00_model_states.pt,
|
4 |
+
--lang_embeddings_path="outs/lang_embeddings/your_instr.pt" \
|
5 |
+
--ctrl_freq=25 # your control frequency
|
main.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from train.train import train
|
4 |
+
|
5 |
+
from accelerate.logging import get_logger
|
6 |
+
|
7 |
+
|
8 |
+
def parse_args(input_args=None):
|
9 |
+
parser = argparse.ArgumentParser(description="Main script for training RDT.")
|
10 |
+
parser.add_argument(
|
11 |
+
"--config_path",
|
12 |
+
type=str,
|
13 |
+
default="configs/base.yaml",
|
14 |
+
help="Path to the configuration file. Default is `configs/base.yaml`.",
|
15 |
+
)
|
16 |
+
parser.add_argument(
|
17 |
+
"--deepspeed",
|
18 |
+
type=str,
|
19 |
+
default=None,
|
20 |
+
help="Enable DeepSpeed and pass the path to its config file or an already initialized DeepSpeed config dictionary",
|
21 |
+
)
|
22 |
+
parser.add_argument(
|
23 |
+
"--pretrained_text_encoder_name_or_path",
|
24 |
+
type=str,
|
25 |
+
default=None,
|
26 |
+
help="Pretrained text encoder name or path if not the same as model_name",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--pretrained_vision_encoder_name_or_path",
|
30 |
+
type=str,
|
31 |
+
default=None,
|
32 |
+
help="Pretrained vision encoder name or path if not the same as model_name",
|
33 |
+
)
|
34 |
+
|
35 |
+
parser.add_argument(
|
36 |
+
"--output_dir",
|
37 |
+
type=str,
|
38 |
+
default="checkpoints",
|
39 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
40 |
+
)
|
41 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
42 |
+
|
43 |
+
parser.add_argument(
|
44 |
+
"--load_from_hdf5",
|
45 |
+
action="store_true",
|
46 |
+
default=False,
|
47 |
+
help=(
|
48 |
+
"Whether to load the dataset directly from HDF5 files. "
|
49 |
+
"If False, the dataset will be loaded using producer-consumer pattern, "
|
50 |
+
"where the producer reads TFRecords and saves them to buffer, and the consumer reads from buffer."
|
51 |
+
)
|
52 |
+
)
|
53 |
+
parser.add_argument(
|
54 |
+
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
55 |
+
)
|
56 |
+
parser.add_argument(
|
57 |
+
"--sample_batch_size", type=int, default=8, help="Batch size (per device) for the sampling dataloader."
|
58 |
+
)
|
59 |
+
parser.add_argument(
|
60 |
+
"--num_sample_batches", type=int, default=2, help="Number of batches to sample from the dataset."
|
61 |
+
)
|
62 |
+
parser.add_argument("--num_train_epochs", type=int, default=1)
|
63 |
+
parser.add_argument(
|
64 |
+
"--max_train_steps",
|
65 |
+
type=int,
|
66 |
+
default=None,
|
67 |
+
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--checkpointing_period",
|
71 |
+
type=int,
|
72 |
+
default=500,
|
73 |
+
help=(
|
74 |
+
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
|
75 |
+
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
|
76 |
+
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
|
77 |
+
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
|
78 |
+
"instructions."
|
79 |
+
),
|
80 |
+
)
|
81 |
+
parser.add_argument(
|
82 |
+
"--checkpoints_total_limit",
|
83 |
+
type=int,
|
84 |
+
default=None,
|
85 |
+
help=(
|
86 |
+
"Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
|
87 |
+
" See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
|
88 |
+
" for more details"
|
89 |
+
),
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--resume_from_checkpoint",
|
93 |
+
type=str,
|
94 |
+
default=None,
|
95 |
+
help=(
|
96 |
+
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
97 |
+
' `--checkpointing_period`, or `"latest"` to automatically select the last available checkpoint.'
|
98 |
+
),
|
99 |
+
)
|
100 |
+
parser.add_argument(
|
101 |
+
"--pretrained_model_name_or_path",
|
102 |
+
type=str,
|
103 |
+
default=None,
|
104 |
+
help=(
|
105 |
+
"Path or name of a pretrained checkpoint to load the model from.\n",
|
106 |
+
" This can be either:\n"
|
107 |
+
" - a string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co, e.g., `robotics-diffusion-transformer/rdt-1b`,\n"
|
108 |
+
" - a path to a *directory* containing model weights saved using [`~RDTRunner.save_pretrained`] method, e.g., `./my_model_directory/`.\n"
|
109 |
+
" - a path to model checkpoint (*.pt), .e.g, `my_model_directory/checkpoint-10000/pytorch_model/mp_rank_00_model_states.pt`"
|
110 |
+
" - `None` if you are randomly initializing model using configuration at `config_path`."
|
111 |
+
)
|
112 |
+
)
|
113 |
+
parser.add_argument(
|
114 |
+
"--gradient_accumulation_steps",
|
115 |
+
type=int,
|
116 |
+
default=1,
|
117 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
118 |
+
)
|
119 |
+
parser.add_argument(
|
120 |
+
"--gradient_checkpointing",
|
121 |
+
action="store_true",
|
122 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--learning_rate",
|
126 |
+
type=float,
|
127 |
+
default=5e-6,
|
128 |
+
help="Initial learning rate (after the potential warmup period) to use.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--cond_mask_prob",
|
132 |
+
type=float,
|
133 |
+
default=0.1,
|
134 |
+
help=(
|
135 |
+
"The probability to randomly mask the conditions (except states) during training. "
|
136 |
+
"If set to 0, the conditions are not masked."
|
137 |
+
),
|
138 |
+
)
|
139 |
+
parser.add_argument(
|
140 |
+
"--cam_ext_mask_prob",
|
141 |
+
type=float,
|
142 |
+
default=-1.0,
|
143 |
+
help=(
|
144 |
+
"The probability to randomly mask the external camera image during training. "
|
145 |
+
"If set to < 0, the external camera image is masked with the probability of `cond_mask_prob`."
|
146 |
+
),
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--state_noise_snr",
|
150 |
+
type=float,
|
151 |
+
default=None,
|
152 |
+
help=(
|
153 |
+
"The signal-to-noise ratio (SNR, unit: dB) for adding noise to the states. "
|
154 |
+
"Default is None, which means no noise is added."
|
155 |
+
),
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--image_aug",
|
159 |
+
action="store_true",
|
160 |
+
default=False,
|
161 |
+
help="Whether or not to apply image augmentation (ColorJitter, blur, noise, etc) to the input images.",
|
162 |
+
)
|
163 |
+
parser.add_argument(
|
164 |
+
"--precomp_lang_embed",
|
165 |
+
action="store_true",
|
166 |
+
default=False,
|
167 |
+
help="Whether or not to use precomputed language embeddings.",
|
168 |
+
)
|
169 |
+
parser.add_argument(
|
170 |
+
"--scale_lr",
|
171 |
+
action="store_true",
|
172 |
+
default=False,
|
173 |
+
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--lr_scheduler",
|
177 |
+
type=str,
|
178 |
+
default="constant",
|
179 |
+
help=(
|
180 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
181 |
+
' "constant", "constant_with_warmup"]'
|
182 |
+
),
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
186 |
+
)
|
187 |
+
parser.add_argument(
|
188 |
+
"--lr_num_cycles",
|
189 |
+
type=int,
|
190 |
+
default=1,
|
191 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
192 |
+
)
|
193 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
194 |
+
parser.add_argument(
|
195 |
+
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
196 |
+
)
|
197 |
+
parser.add_argument(
|
198 |
+
"--dataloader_num_workers",
|
199 |
+
type=int,
|
200 |
+
default=0,
|
201 |
+
help=(
|
202 |
+
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
203 |
+
),
|
204 |
+
)
|
205 |
+
parser.add_argument("--alpha", type=float, default=0.9, help="The moving average coefficient for each dataset's loss.")
|
206 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
207 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
208 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
209 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
210 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
211 |
+
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
212 |
+
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
213 |
+
parser.add_argument(
|
214 |
+
"--hub_model_id",
|
215 |
+
type=str,
|
216 |
+
default=None,
|
217 |
+
help="The name of the repository to keep in sync with the local `output_dir`.",
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--logging_dir",
|
221 |
+
type=str,
|
222 |
+
default="logs",
|
223 |
+
help=(
|
224 |
+
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
225 |
+
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
226 |
+
),
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--allow_tf32",
|
230 |
+
action="store_true",
|
231 |
+
help=(
|
232 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
233 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
234 |
+
),
|
235 |
+
)
|
236 |
+
parser.add_argument(
|
237 |
+
"--report_to",
|
238 |
+
type=str,
|
239 |
+
default="tensorboard",
|
240 |
+
help=(
|
241 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
242 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
243 |
+
),
|
244 |
+
)
|
245 |
+
parser.add_argument(
|
246 |
+
"--sample_period",
|
247 |
+
type=int,
|
248 |
+
default=-1,
|
249 |
+
help=(
|
250 |
+
"Run sampling every X steps. During the sampling phase, the model will sample a trajectory"
|
251 |
+
" and report the error between the sampled trajectory and groud-truth trajectory"
|
252 |
+
" in the training batch."
|
253 |
+
),
|
254 |
+
)
|
255 |
+
parser.add_argument(
|
256 |
+
"--mixed_precision",
|
257 |
+
type=str,
|
258 |
+
default=None,
|
259 |
+
choices=["no", "fp16", "bf16"],
|
260 |
+
help=(
|
261 |
+
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
262 |
+
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
263 |
+
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
264 |
+
),
|
265 |
+
)
|
266 |
+
|
267 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
268 |
+
parser.add_argument(
|
269 |
+
"--set_grads_to_none",
|
270 |
+
action="store_true",
|
271 |
+
help=(
|
272 |
+
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
|
273 |
+
" behaviors, so disable this argument if it causes any problems. More info:"
|
274 |
+
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
|
275 |
+
),
|
276 |
+
)
|
277 |
+
|
278 |
+
parser.add_argument('--dataset_type',
|
279 |
+
type=str,
|
280 |
+
default="pretrain",
|
281 |
+
required=False,
|
282 |
+
help="Whether to load the pretrain dataset or finetune dataset."
|
283 |
+
)
|
284 |
+
parser.add_argument('--dataset_name', type=str)
|
285 |
+
|
286 |
+
if input_args is not None:
|
287 |
+
args = parser.parse_args(input_args)
|
288 |
+
else:
|
289 |
+
args = parser.parse_args()
|
290 |
+
|
291 |
+
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
292 |
+
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
293 |
+
args.local_rank = env_local_rank
|
294 |
+
|
295 |
+
return args
|
296 |
+
|
297 |
+
|
298 |
+
if __name__ == "__main__":
|
299 |
+
logger = get_logger(__name__)
|
300 |
+
args = parse_args()
|
301 |
+
train(args, logger)
|
models/ema_model.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Reference: DiffusionPolicy [https://github.com/real-stanford/diffusion_policy]
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
5 |
+
|
6 |
+
|
7 |
+
class EMAModel:
|
8 |
+
"""
|
9 |
+
Exponential Moving Average of models weights
|
10 |
+
"""
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
model,
|
14 |
+
update_after_step=0,
|
15 |
+
inv_gamma=1.0,
|
16 |
+
power=2 / 3,
|
17 |
+
min_value=0.0,
|
18 |
+
max_value=0.9999
|
19 |
+
):
|
20 |
+
"""
|
21 |
+
@crowsonkb's notes on EMA Warmup:
|
22 |
+
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
|
23 |
+
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
|
24 |
+
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
|
25 |
+
at 215.4k steps).
|
26 |
+
Args:
|
27 |
+
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
|
28 |
+
power (float): Exponential factor of EMA warmup. Default: 2/3.
|
29 |
+
min_value (float): The minimum EMA decay rate. Default: 0.
|
30 |
+
"""
|
31 |
+
|
32 |
+
self.averaged_model = model
|
33 |
+
self.averaged_model.eval()
|
34 |
+
self.averaged_model.requires_grad_(False)
|
35 |
+
|
36 |
+
self.update_after_step = update_after_step
|
37 |
+
self.inv_gamma = inv_gamma
|
38 |
+
self.power = power
|
39 |
+
self.min_value = min_value
|
40 |
+
self.max_value = max_value
|
41 |
+
|
42 |
+
self.decay = 0.0
|
43 |
+
self.optimization_step = 0
|
44 |
+
|
45 |
+
def get_decay(self, optimization_step):
|
46 |
+
"""
|
47 |
+
Compute the decay factor for the exponential moving average.
|
48 |
+
"""
|
49 |
+
step = max(0, optimization_step - self.update_after_step - 1)
|
50 |
+
value = 1 - (1 + step / self.inv_gamma) ** -self.power
|
51 |
+
|
52 |
+
if step <= 0:
|
53 |
+
return 0.0
|
54 |
+
|
55 |
+
return max(self.min_value, min(value, self.max_value))
|
56 |
+
|
57 |
+
@torch.no_grad()
|
58 |
+
def step(self, new_model):
|
59 |
+
self.decay = self.get_decay(self.optimization_step)
|
60 |
+
|
61 |
+
# old_all_dataptrs = set()
|
62 |
+
# for param in new_model.parameters():
|
63 |
+
# data_ptr = param.data_ptr()
|
64 |
+
# if data_ptr != 0:
|
65 |
+
# old_all_dataptrs.add(data_ptr)
|
66 |
+
|
67 |
+
all_dataptrs = set()
|
68 |
+
for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
|
69 |
+
for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
|
70 |
+
# iterative over immediate parameters only.
|
71 |
+
if isinstance(param, dict):
|
72 |
+
raise RuntimeError('Dict parameter not supported')
|
73 |
+
|
74 |
+
# data_ptr = param.data_ptr()
|
75 |
+
# if data_ptr != 0:
|
76 |
+
# all_dataptrs.add(data_ptr)
|
77 |
+
|
78 |
+
if isinstance(module, _BatchNorm):
|
79 |
+
# skip batchnorms
|
80 |
+
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
81 |
+
elif not param.requires_grad:
|
82 |
+
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
|
83 |
+
else:
|
84 |
+
ema_param.mul_(self.decay)
|
85 |
+
ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
|
86 |
+
|
87 |
+
# verify that iterating over module and then parameters is identical to parameters recursively.
|
88 |
+
# assert old_all_dataptrs == all_dataptrs
|
89 |
+
self.optimization_step += 1
|