import time import json import os import time import argparse import sys import signal import random from multiprocessing import Process import numpy as np import tensorflow as tf import yaml from data.vla_dataset import VLADataset from data.filelock import FileLock # Producer does not need GPU tf.config.set_visible_devices([], 'GPU') # Read the config with open('configs/base.yaml', 'r') as file: config = yaml.safe_load(file) # Load some constants from the config BUF_PATH = config['dataset']['buf_path'] BUF_NUM_CHUNKS = config['dataset']['buf_num_chunks'] if BUF_NUM_CHUNKS < 1: raise ValueError("Config `buf_num_chunks` must be at least 1.") BUF_CHUNK_SIZE = config['dataset']['buf_chunk_size'] if BUF_CHUNK_SIZE < 1: raise ValueError("Config `buf_chunk_size` must be at least 1.") def get_dirty_item(chunk_dir): """ Get indexes of dirty items in a chunk. """ dirty_bit = read_dirty_bit(chunk_dir) return np.where(dirty_bit)[0].tolist() def get_clean_item(chunk_dir): """ Get indexes of clean items in a chunk. """ dirty_bit = read_dirty_bit(chunk_dir) return np.where(1 - dirty_bit)[0].tolist() def save_dirty_bit(chunk_dir, dirty_bit): """ Save the dirty bit to the chunk directory. """ time_stmp = time.time() while time.time() - time_stmp < 10.0: try: file_path = os.path.join(chunk_dir, "dirty_bit") lock = FileLock(file_path) lock.acquire_write_lock() with open(file_path, 'wb') as file: file.write(dirty_bit.tobytes()) lock.release_lock() return except KeyboardInterrupt: lock.release_lock() raise KeyboardInterrupt except BaseException: lock.release_lock() continue # raise RuntimeError("Failed to save dirty bit.") print("Failed to save dirty bit.") def read_dirty_bit(chunk_dir): """ Read the dirty bit from the chunk directory. """ # If error occurs, retry time_stmp = time.time() while time.time() - time_stmp < 10.0: try: file_path = os.path.join(chunk_dir, "dirty_bit") lock = FileLock(file_path) lock.acquire_read_lock() with open(file_path, 'rb') as file: dirty_bit = np.frombuffer(file.read(), dtype=np.uint8).copy() lock.release_lock() assert len(dirty_bit) == BUF_CHUNK_SIZE return dirty_bit except KeyboardInterrupt: lock.release_lock() raise KeyboardInterrupt except BaseException: lock.release_lock() continue # If failed to read the dirty bit, return all ones for robustness return np.ones(BUF_CHUNK_SIZE, dtype=np.uint8) def save_sample(step_dict, chunk_dir, chunk_item_idx): """ Save a sample to the chunk directory. """ # Save the json content time_stmp = time.time() while time.time() - time_stmp < 10.0: try: locks = [] json_content = step_dict['json_content'] file_path = os.path.join(chunk_dir, f"json_content_{chunk_item_idx}.json") lock = FileLock(file_path) locks.append(lock) lock.acquire_write_lock() with open(file_path, 'w') as file: json.dump(json_content, file, indent=4) lock.release_lock() # Save all other tensors in a npz file_path = os.path.join(chunk_dir, f"sample_{chunk_item_idx}.npz") lock = FileLock(file_path) locks.append(lock) lock.acquire_write_lock() with open(file_path, 'wb') as file: np.savez( file, step_id=step_dict['step_id'].numpy(), state_chunk=step_dict['state_chunk'].numpy(), state_chunk_time_mask=step_dict['state_chunk_time_mask'].numpy(), action_chunk=step_dict['action_chunk'].numpy(), action_chunk_time_mask=step_dict['action_chunk_time_mask'].numpy(), state_vec_mask=step_dict['state_vec_mask'].numpy(), past_frames_0=step_dict['past_frames_0'].numpy(), past_frames_0_time_mask=step_dict['past_frames_0_time_mask'].numpy(), past_frames_1=step_dict['past_frames_1'].numpy(), past_frames_1_time_mask=step_dict['past_frames_1_time_mask'].numpy(), past_frames_2=step_dict['past_frames_2'].numpy(), past_frames_2_time_mask=step_dict['past_frames_2_time_mask'].numpy(), past_frames_3=step_dict['past_frames_3'].numpy(), past_frames_3_time_mask=step_dict['past_frames_3_time_mask'].numpy(), state_std=step_dict['state_std'].numpy(), state_mean=step_dict['state_mean'].numpy(), state_norm=step_dict['state_norm'].numpy(), ) lock.release_lock() return except KeyboardInterrupt: for lock in locks: lock.release_lock() raise KeyboardInterrupt except BaseException: for lock in locks: lock.release_lock() continue # raise RuntimeError("Failed to save sample.") print("Failed to save sample.") def run_producer(seed, num_workers, worker_id, fill_up, clean_dirty, dataset_type): """ Run the producer. The producer will first fill up the buffer with samples. Then it will keep replacing dirty samples (i.e., samples that have been read by the consumer) with new samples. """ vla_dataset = VLADataset(seed=seed, dataset_type=dataset_type) chunk_start_idx = worker_id * BUF_NUM_CHUNKS // num_workers chunk_end_idx = (worker_id + 1) * BUF_NUM_CHUNKS // num_workers if fill_up: print(f"Worker {worker_id}: Start filling up the buffer...") elif clean_dirty: # Only refresh the dirty bits print(f"Worker {worker_id}: Start refreshing the dirty bits...") for chunk_idx in range(chunk_start_idx, chunk_end_idx): chunk_dir = os.path.join(BUF_PATH, f"chunk_{chunk_idx}") dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) save_dirty_bit(chunk_dir, dirty_bit) print(f"Worker {worker_id}: Refreshed the dirty bits.") fill_chunk_idx = chunk_start_idx fill_chunk_item_idx = 0 dirty_chunk_idx = chunk_start_idx dirty_chunk_item_idxs = [] time_stmp = time.time() for episode_steps in vla_dataset: for step in episode_steps: if fill_up and fill_chunk_idx < chunk_end_idx: # Fill up the buffer chunk_dir = os.path.join(BUF_PATH, f"chunk_{fill_chunk_idx}") if fill_chunk_item_idx == 0: # Create a new chunk os.makedirs(chunk_dir, exist_ok=True) # Write the dirty bit of size BUF_CHUNK_SIZE dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) save_dirty_bit(chunk_dir, dirty_bit) # Save the sample save_sample(step, chunk_dir, fill_chunk_item_idx) # print(f"Filled up chunk {fill_chunk_item_idx+1}/{BUF_CHUNK_SIZE} {fill_chunk_idx+1}/{BUF_NUM_CHUNKS}") local_fill_chunk_idx = fill_chunk_idx - chunk_start_idx local_num_chunks = chunk_end_idx - chunk_start_idx if (local_fill_chunk_idx % 10 == 0 or local_fill_chunk_idx == local_num_chunks - 1) and fill_chunk_item_idx == 0: print(f"Worker {worker_id}: Filled up chunk {local_fill_chunk_idx+1}/{local_num_chunks}") fill_chunk_item_idx += 1 if fill_chunk_item_idx == BUF_CHUNK_SIZE: fill_chunk_idx += 1 fill_chunk_item_idx = 0 if fill_chunk_idx == BUF_NUM_CHUNKS: print(f"Worker {worker_id}: Buffer filled up. Start replacing dirty samples...") else: # Search for the dirty chunk to replace while len(dirty_chunk_item_idxs) == 0: dirty_chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}") dirty_chunk_item_idxs = get_dirty_item(dirty_chunk_dir) # Print the dirty ratio if time.time() - time_stmp > 2.0: dirty_ratio = len(dirty_chunk_item_idxs) / BUF_CHUNK_SIZE print(f"Worker {worker_id}: Dirty Ratio for Chunk {dirty_chunk_idx}: {dirty_ratio:.2f}") time_stmp = time.time() if len(dirty_chunk_item_idxs) > 0: # Lock the chunk dirty_bit = np.ones(BUF_CHUNK_SIZE, dtype=np.uint8) save_dirty_bit(dirty_chunk_dir, dirty_bit) # Iterate over the chunks dirty_chunk_idx += 1 if dirty_chunk_idx == chunk_end_idx: dirty_chunk_idx = chunk_start_idx # Replace the dirty item dirty_item_idx = dirty_chunk_item_idxs.pop() chunk_dir = os.path.join(BUF_PATH, f"chunk_{dirty_chunk_idx}") # Save the sample save_sample(step, chunk_dir, dirty_item_idx) # If we have replaced all dirty items in the chunk if len(dirty_chunk_item_idxs) == 0: # Unlock the chunk dirty_bit = np.zeros(BUF_CHUNK_SIZE, dtype=np.uint8) save_dirty_bit(dirty_chunk_dir, dirty_bit) print(f"Worker {worker_id}: Replaced dirty chunk {dirty_chunk_idx}.") if __name__ == '__main__': # Args: n_workers, fill_up parser = argparse.ArgumentParser() parser.add_argument('--n_workers', type=int, default=2, help="Number of parallel workers. It should be less than or equal to the number of chunks.") parser.add_argument('--fill_up', action='store_true', help="Whether to fill up the buffer before replacing dirty samples.") parser.add_argument('--clean_dirty', action='store_true', help="Whether to clean the dirty bits before replacing dirty samples. This option is ignored when `fill_up` is set.") parser.add_argument('--seed', type=int, default=None, help="Random seed. If not set, the seed will be randomly generated.") parser.add_argument('--dataset_type', type=str, default="pretrain", help="Whether to load the pretrain dataset or finetune dataset.") # Run the producer args = parser.parse_args() if args.seed is not None: print(f"Base seed: {args.seed}") random.seed(args.seed) processes = [] process_seeds = [random.randint(0, 2**32) for _ in range(args.n_workers)] print(f"Process seeds: {process_seeds}") def signal_handler(sig, frame): print("Ctrl+C received. Terminating child processes...") for p in processes: p.terminate() sys.exit(0) signal.signal(signal.SIGINT, signal_handler) for worker_id in range(args.n_workers): p = Process(target=run_producer, args=( process_seeds[worker_id], args.n_workers, worker_id, args.fill_up, args.clean_dirty, args.dataset_type)) p.start() processes.append(p) for p in processes: p.join()