|
import traceback |
|
import time |
|
import os |
|
import json |
|
import math |
|
import random |
|
from typing import Dict, Sequence |
|
|
|
import numpy as np |
|
import torch |
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
from PIL import Image |
|
import transformers |
|
|
|
from data.filelock import FileLock |
|
from data.hdf5_vla_dataset import TabletopHDF5VLADataset, AnubisHDF5VLADataset |
|
from train.image_corrupt import image_corrupt |
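
# Buffer layout (as inferred from the code below): the buffer directory holds
# chunk subdirectories ("chunk_0", "chunk_1", ...). Each chunk stores up to
# `buf_chunk_size` samples (a JSON file plus an .npz file per slot) and a
# "dirty_bit" file with one uint8 per slot: 0 marks a clean (unread) sample,
# 1 marks a consumed one. Every access goes through FileLock so that a
# separate producer process can refill consumed slots concurrently.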
|
|
|
|
|
def get_clean_item(chunk_dir):
    """
    Get the indexes of clean (not-yet-consumed) items in a chunk.
    """
    dirty_bit = read_dirty_bit(chunk_dir)
    # A zero in the dirty bit marks a slot that has not been consumed yet.
    return np.where(1 - dirty_bit)[0].tolist()
|
|
|
|
|
def save_dirty_bit(chunk_dir, dirty_bit):
    """
    Save the dirty bit to the chunk directory.
    """
    # Retry under a write lock for up to 10 seconds before giving up.
    time_stamp = time.time()
    while time.time() - time_stamp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_write_lock()
            with open(file_path, 'wb') as file:
                file.write(dirty_bit.tobytes())
            lock.release_lock()
            return
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    raise RuntimeError("Failed to save dirty bit.")
|
|
|
|
|
def read_dirty_bit(chunk_dir):
    """
    Read the dirty bit from the chunk directory.
    """
    # Retry under a read lock for up to 10 seconds before giving up.
    time_stamp = time.time()
    while time.time() - time_stamp < 10.0:
        try:
            file_path = os.path.join(chunk_dir, "dirty_bit")
            lock = FileLock(file_path)
            lock.acquire_read_lock()
            with open(file_path, 'rb') as file:
                dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy()
            lock.release_lock()
            assert len(dirty_bit) > 0
            return dirty_bit
        except KeyboardInterrupt:
            lock.release_lock()
            raise
        except BaseException:
            lock.release_lock()
            continue
    raise RuntimeError("Failed to read dirty bit.")
|
|
|
|
|
class VLAConsumerDataset(Dataset): |
|
"""A vision-languange-action Dataset for supervised training. |
|
This dataset will load data from the buffer directory. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
config, |
|
tokenizer, |
|
image_processor, |
|
num_cameras, |
|
img_history_size, |
|
image_size=None, |
|
auto_adjust_image_brightness=False, |
|
image_aug=False, |
|
dataset_type='pretrain', |
|
cond_mask_prob=0.1, |
|
cam_ext_mask_prob=-1.0, |
|
state_noise_snr=None, |
|
use_hdf5=False, |
|
use_precomp_lang_embed=False, |
|
task_name=None |
|
): |
|
        super().__init__()
|
|
|
|
|
with open("configs/dataset_control_freq.json", 'r') as fp: |
|
self.control_freq = json.load(fp) |
|
|
|
dataset_names_cfg = 'configs/pretrain_datasets.json' \ |
|
if dataset_type == 'pretrain' else 'configs/finetune_datasets.json' |
|
with open(dataset_names_cfg, 'r') as file: |
|
DATASET_NAMES = json.load(file) |
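        # NOTE: DATASET_NAMES lists every dataset of the chosen split, but only
        # task_name is used below to build the name/id maps.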
|
|
|
|
|
|
|
        # Register only the configured task; it always maps to dataset id 0.
        self.dataset_name2id = {task_name: 0}
        self.dataset_id2name = {0: task_name}
|
|
|
self.image_processor = image_processor |
|
|
|
self.buffer_dir = config["buf_path"] |
|
self.num_chunks = config["buf_num_chunks"] |
|
self.chunk_size = config["buf_chunk_size"] |
|
self.tokenizer_max_length = config["tokenizer_max_length"] |
|
self.image_aspect_ratio = config["image_aspect_ratio"] |
|
self.state_noise_snr = state_noise_snr |
|
self.num_cameras = num_cameras |
|
self.img_history_size = img_history_size |
|
self.cond_mask_prob = cond_mask_prob |
|
self.cam_ext_mask_prob = cam_ext_mask_prob |
|
self.use_hdf5 = use_hdf5 |
|
self.hdf5_dataset = None |
|
if use_hdf5: |
|
self.hdf5_dataset = AnubisHDF5VLADataset(task_name) |
|
self.use_precomp_lang_embed = use_precomp_lang_embed |
|
if use_precomp_lang_embed: |
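            # Embedding of the empty instruction, used when the language
            # condition is dropped (see cond_mask_prob in __getitem__).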
|
self.empty_lang_embed = torch.load("data/empty_lang_embed.pt") |
|
|
|
|
|
with open("configs/dataset_stat.json", 'r') as f: |
|
dataset_stat = json.load(f) |
|
self.dataset_stat = dataset_stat |
|
|
|
self.tokenizer = tokenizer |
|
self.image_size = image_size |
|
self.auto_adjust_image_brightness = auto_adjust_image_brightness |
|
self.image_aug = image_aug |
|
|
|
self.last_content = None |
|
self.last_meta = None |
|
|
|
def get_dataset_name2id(self): |
|
return self.dataset_name2id |
|
|
|
def get_dataset_id2name(self): |
|
return self.dataset_id2name |
|
|
|
    @staticmethod
    def pairwise(iterable):
        """Group an iterable into consecutive non-overlapping pairs:
        s -> (s0, s1), (s2, s3), (s4, s5), ..."""
        a = iter(iterable)
        return zip(a, a)
|
|
|
@staticmethod |
|
def _load_data_from_chunk(chunk_dir, chunk_item_idx): |
|
|
|
        # Retry for up to 10 seconds before giving up.
        time_stamp = time.time()
        while time.time() - time_stamp < 10.0:
|
try: |
|
locks = [] |
|
file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json") |
|
lock = FileLock(file_path) |
|
locks.append(lock) |
|
lock.acquire_read_lock() |
|
with open(file_path, 'r') as file: |
|
json_content = json.load(file) |
|
lock.release_lock() |
|
file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz") |
|
lock = FileLock(file_path) |
|
locks.append(lock) |
|
lock.acquire_read_lock() |
|
with open(file_path, 'rb') as file: |
|
sample_dict = np.load(file) |
|
meta = tuple(sample_dict.values()) |
|
lock.release_lock() |
|
return json_content, meta |
|
except KeyboardInterrupt: |
|
for lock in locks: |
|
lock.release_lock() |
|
                raise
|
except BaseException: |
|
for lock in locks: |
|
lock.release_lock() |
|
continue |
|
raise RuntimeError("Failed to load sample.") |
|
|
|
def __len__(self) -> int: |
|
if self.use_hdf5: |
|
return len(self.hdf5_dataset) |
|
else: |
|
return self.num_chunks * self.chunk_size |
|
|
|
def _safe_load(self, index): |
|
read_chunk_item_indices = [] |
|
|
|
read_chunk_idx = index // self.chunk_size |
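        # Scan chunks, wrapping around, until one containing at least one
        # clean (unread) item is found.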
|
while len(read_chunk_item_indices) == 0: |
|
read_chunk_dir = os.path.join(self.buffer_dir, f"chunk_{read_chunk_idx}") |
|
try: |
|
read_chunk_item_indices = get_clean_item(read_chunk_dir) |
|
except BaseException as e: |
|
|
|
print("Error catched when searching a clean chunk:", e) |
|
traceback.print_exc() |
|
read_chunk_item_indices = [] |
|
read_chunk_idx = (read_chunk_idx + 1) % self.num_chunks |
|
|
|
|
|
|
|
        # Deterministically pick one of the clean items based on the index.
        selected_item_index = index % len(read_chunk_item_indices)
        read_chunk_item_index = read_chunk_item_indices[selected_item_index]
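        # Mark the chosen item as dirty (consumed) so other workers skip it.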
|
|
|
|
|
try: |
|
dirty_bit = read_dirty_bit(read_chunk_dir) |
|
dirty_bit[read_chunk_item_index] = 1 |
|
save_dirty_bit(read_chunk_dir, dirty_bit) |
|
except BaseException as e: |
|
|
|
print("Error catched when modifying the dirty bit:", e) |
|
traceback.print_exc() |
|
|
|
|
|
try: |
|
content, meta = self._load_data_from_chunk(read_chunk_dir, read_chunk_item_index) |
|
self.last_content, self.last_meta = content, meta |
|
except BaseException as e: |
|
|
|
print("Error catched when loading sample:", e) |
|
traceback.print_exc() |
|
|
|
|
|
        # Fall back to the most recently loaded sample; this assumes at least
        # one prior successful load (otherwise both values are None).
        content, meta = self.last_content, self.last_meta
|
|
|
return (content, *meta) |
|
|
|
def __getitem__(self, index): |
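        # Keep retrying (moving on to the next index) until one sample is
        # processed without error.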
|
|
|
while True: |
|
data_dict = None |
|
try: |
|
if self.use_hdf5: |
|
res = self.hdf5_dataset.get_item() |
|
content = res['meta'] |
|
states = res['state'] |
|
actions = res['actions'] |
|
state_elem_mask = res['state_indicator'] |
|
image_metas = [ |
|
res['cam_high'], res['cam_high_mask'], |
|
res['cam_right_wrist'], res['cam_right_wrist_mask'], |
|
res['cam_left_wrist'], res['cam_left_wrist_mask'], |
|
] |
|
state_std = res['state_std'] |
|
state_mean = res['state_mean'] |
|
state_norm = res['state_norm'] |
|
else: |
|
(content, _, states, _, actions, _, |
|
state_elem_mask, *image_metas, |
|
state_std, state_mean, state_norm) = self._safe_load(index) |
|
|
|
data_dict = {} |
|
data_dict['dataset_name'] = content['dataset_name'] |
|
data_dict['data_idx'] = self.dataset_name2id[data_dict['dataset_name']] |
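                # With probability cond_mask_prob each condition is replaced by
                # a neutral value (zero ctrl_freq, dataset-mean state, zero
                # element mask, empty instruction), a condition-dropout scheme
                # in the style of classifier-free guidance.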
|
data_dict['ctrl_freq'] = self.control_freq[data_dict['dataset_name']] \ |
|
if random.random() > self.cond_mask_prob else 0 |
|
|
|
                if self.state_noise_snr is not None:
                    # Corrupt the state with Gaussian noise at the given SNR
                    # (in dB): noise_std = state_std / sqrt(10 ** (SNR / 10)).
                    states += np.random.normal(
                        0.0, state_std / np.sqrt(10 ** (self.state_noise_snr / 10)),
                        states.shape)
|
ds_state_mean = np.array(self.dataset_stat[data_dict['dataset_name']]['state_mean']) |
|
ds_state_mean = np.tile(ds_state_mean[None], (states.shape[0], 1)) |
|
|
|
data_dict["states"] = states \ |
|
if random.random() > self.cond_mask_prob else ds_state_mean |
|
data_dict["actions"] = actions |
|
data_dict["state_elem_mask"] = state_elem_mask \ |
|
if random.random() > self.cond_mask_prob else np.zeros_like(state_elem_mask) |
|
|
|
|
|
data_dict["state_norm"] = state_norm |
|
|
|
|
|
|
|
                # Constant background in the image processor's mean color; it
                # stands in for masked-out or invalid camera views.
                background_color = np.array([
                    int(x * 255) for x in self.image_processor.image_mean
                ], dtype=np.uint8).reshape(1, 1, 3)
                background_image = np.ones((
                    self.image_processor.size["height"],
                    self.image_processor.size["width"], 3), dtype=np.uint8
                ) * background_color
|
|
|
                # image_metas alternates (frames, mask) entries per camera;
                # regroup them into (frames, mask) pairs.
                image_metas = list(self.pairwise(image_metas))
|
                mask_probs = [self.cond_mask_prob] * self.num_cameras
                if self.cam_ext_mask_prob >= 0.0:
                    # A non-negative value overrides the drop probability of
                    # the first (presumably exterior) camera.
                    mask_probs[0] = self.cam_ext_mask_prob
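                # Rearrange to time-major order: for each timestep, all camera
                # views; dropped or invalid views become the background image.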
|
rearranged_images = [] |
|
for i in range(self.img_history_size): |
|
for j in range(self.num_cameras): |
|
images, image_mask = image_metas[j] |
|
image, valid = images[i], image_mask[i] |
|
if valid and (math.prod(image.shape) > 0) and \ |
|
(random.random() > mask_probs[j]): |
|
rearranged_images.append((image, True)) |
|
else: |
|
rearranged_images.append((background_image.copy(), False)) |
|
|
|
preprocessed_images = [] |
|
processor = self.image_processor |
|
for image, valid in rearranged_images: |
|
image = Image.fromarray(image) |
|
if self.image_size is not None: |
|
image = transforms.Resize(self.image_size)(image) |
|
|
|
|
|
                    # Brighten very dark frames: mean brightness is normalized
                    # to [0, 1], and frames below 0.15 get a fixed boost.
                    if valid and self.auto_adjust_image_brightness:
                        pixel_values = list(image.getdata())
                        average_brightness = sum(sum(pixel) for pixel in pixel_values) \
                            / (len(pixel_values) * 255.0 * 3)
                        if average_brightness <= 0.15:
                            image = transforms.ColorJitter(brightness=(1.75, 1.75))(image)
|
|
|
|
|
if valid and self.image_aug and (random.random() > 0.5): |
|
                        aug_type = random.choice([
                            "corrupt_only", "color_only", "both"])
                        if aug_type != "corrupt_only":
|
image = transforms.ColorJitter( |
|
brightness=0.3, contrast=0.4, saturation=0.5, hue=0.03)(image) |
|
if aug_type != "color_only": |
|
image = image_corrupt(image) |
|
|
|
                    # 'pad' letterboxes the image onto a square canvas filled
                    # with the processor's mean color, preserving aspect ratio.
                    if self.image_aspect_ratio == 'pad':
|
def expand2square(pil_img, background_color): |
|
width, height = pil_img.size |
|
if width == height: |
|
return pil_img |
|
elif width > height: |
|
result = Image.new(pil_img.mode, (width, width), background_color) |
|
result.paste(pil_img, (0, (width - height) // 2)) |
|
return result |
|
else: |
|
result = Image.new(pil_img.mode, (height, height), background_color) |
|
result.paste(pil_img, ((height - width) // 2, 0)) |
|
return result |
|
image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) |
|
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] |
|
preprocessed_images.append(image) |
|
data_dict["images"] = preprocessed_images |
|
|
|
if self.use_precomp_lang_embed: |
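                    # Here "instruction" holds the path to a precomputed
                    # embedding file; strip a stray trailing period so the
                    # path resolves.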
|
if content["instruction"][-1] == ".": |
|
content["instruction"] = content["instruction"][:-1] |
|
data_dict["lang_embed"] = torch.load(content["instruction"])['embeddings'][0] \ |
|
if random.random() > self.cond_mask_prob else self.empty_lang_embed |
|
else: |
|
instruction = content["instruction"] \ |
|
if random.random() > self.cond_mask_prob else "" |
|
data_dict["input_ids"] = self.tokenizer( |
|
instruction, |
|
return_tensors="pt", |
|
padding="longest", |
|
truncation=False, |
|
).input_ids[0] |
|
|
|
assert len(data_dict["input_ids"]) <= self.tokenizer_max_length, \ |
|
f"Instruction length {len(data_dict['input_ids'])} exceeds the maximum length {self.tokenizer_max_length}." |
|
|
|
                # Convert any remaining numpy arrays to tensors.
                for k, v in data_dict.items():
                    if isinstance(v, np.ndarray):
                        data_dict[k] = torch.from_numpy(v)

                # Sanity check: every value is now a tensor or a Python scalar.
                for k, v in data_dict.items():
                    assert not isinstance(v, np.ndarray), f"key: {k}, value: {v}"
|
|
|
|
|
return data_dict |
|
except BaseException as e: |
|
|
|
                if data_dict is not None:
                    print(f"Error caught while processing a sample from {data_dict.get('dataset_name')}:", e)
                else:
                    print("Error caught while processing a sample:", e)
|
traceback.print_exc() |
|
|
|
index = (index + 1) % len(self) |
|
|
|
|
|
class DataCollatorForVLAConsumerDataset(object): |
|
"""Collate examples for supervised training.""" |
|
|
|
def __init__(self, tokenizer: transformers.PreTrainedTokenizer) -> None: |
|
self.tokenizer = tokenizer |
|
|
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
batch = { |
|
"states": [], |
|
"actions": [], |
|
"state_elem_mask": [], |
|
"state_norm": [], |
|
"images": [], |
|
"data_indices": [], |
|
"ctrl_freqs": [] |
|
} |
|
input_ids = [] |
|
lang_embeds = [] |
|
lang_embed_lens = [] |
|
|
|
for instance in instances: |
|
|
|
keys_to_check = [ |
|
'states', 'actions', |
|
'state_elem_mask', 'state_norm', |
|
] |
|
for key in keys_to_check: |
|
if isinstance(instance[key], torch.Tensor): |
|
item = instance[key] |
|
else: |
|
item = torch.from_numpy(instance[key]) |
|
batch[key].append(item) |
|
|
|
if "input_ids" in instance: |
|
input_ids.append(instance["input_ids"]) |
|
else: |
|
lang_embeds.append(instance["lang_embed"]) |
|
lang_embed_lens.append(instance["lang_embed"].shape[0]) |
|
|
|
batch["images"].append(torch.stack(instance["images"], dim=0)) |
|
batch["data_indices"].append(instance["data_idx"]) |
|
batch["ctrl_freqs"].append(instance["ctrl_freq"]) |
|
|
|
keys_to_stack = [ |
|
'states', 'actions', |
|
'state_elem_mask', 'state_norm', |
|
"images" |
|
] |
|
for key in keys_to_stack: |
|
batch[key] = torch.stack(batch[key], dim=0) |
|
|
|
batch["ctrl_freqs"] = torch.tensor(batch["ctrl_freqs"]) |
|
|
|
if len(input_ids) > 0: |
|
input_ids = torch.nn.utils.rnn.pad_sequence( |
|
input_ids, |
|
batch_first=True, |
|
padding_value=self.tokenizer.pad_token_id) |
|
batch["input_ids"] = input_ids |
|
batch["lang_attn_mask"] = input_ids.ne(self.tokenizer.pad_token_id) |
|
else: |
|
lang_embeds = torch.nn.utils.rnn.pad_sequence( |
|
lang_embeds, |
|
batch_first=True, |
|
padding_value=0) |
|
input_lang_attn_mask = torch.zeros( |
|
lang_embeds.shape[0], lang_embeds.shape[1], dtype=torch.bool) |
|
for i, l in enumerate(lang_embed_lens): |
|
input_lang_attn_mask[i, :l] = True |
|
batch["lang_embeds"] = lang_embeds |
|
batch["lang_attn_mask"] = input_lang_attn_mask |
|
|
|
|
|
return batch |
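

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint names, config values, and task name
    # below are illustrative assumptions, not project defaults, and the
    # configs/*.json files read by VLAConsumerDataset must exist.
    from torch.utils.data import DataLoader
    from transformers import AutoImageProcessor, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-base")  # hypothetical choice
    image_processor = AutoImageProcessor.from_pretrained(
        "google/siglip-so400m-patch14-384")  # hypothetical choice

    config = {
        "buf_path": "/tmp/vla_buffer",  # hypothetical buffer location
        "buf_num_chunks": 512,
        "buf_chunk_size": 64,
        "tokenizer_max_length": 1024,
        "image_aspect_ratio": "pad",
    }
    dataset = VLAConsumerDataset(
        config=config,
        tokenizer=tokenizer,
        image_processor=image_processor,
        num_cameras=3,
        img_history_size=2,
        use_hdf5=True,  # read directly from HDF5, bypassing the buffer
        task_name="example_task",  # hypothetical task name
    )
    loader = DataLoader(
        dataset, batch_size=8,
        collate_fn=DataCollatorForVLAConsumerDataset(tokenizer))
    batch = next(iter(loader))
    print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})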
|
|