""" This file will compute the min, max, mean, and standard deviation of each datasets in `pretrain_datasets.json` or `pretrain_datasets.json`. """ import json import argparse import os # from multiprocessing import Pool, Manager import tensorflow as tf import numpy as np from tqdm import tqdm from data.vla_dataset import VLADataset from data.hdf5_vla_dataset import HDF5VLADataset from data.preprocess import generate_json_state # Process each dataset to get the statistics @tf.autograph.experimental.do_not_convert def process_dataset(name_dataset_pair): # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}") dataset_iter = name_dataset_pair[1] MAX_EPISODES = 100000 EPS = 1e-8 # For debugging # MAX_EPISODES = 10 episode_cnt = 0 state_sum = 0 state_sum_sq = 0 z_state_sum = 0 z_state_sum_sq = 0 state_cnt = 0 nz_state_cnt = None state_max = None state_min = None for episode in dataset_iter: episode_cnt += 1 if episode_cnt % 1000 == 0: print(f"Processing episodes {episode_cnt}/{MAX_EPISODES}") if episode_cnt > MAX_EPISODES: break episode_dict = episode['episode_dict'] dataset_name = episode['dataset_name'] res_tup = generate_json_state( episode_dict, dataset_name ) states = res_tup[1] # Convert to numpy states = states.numpy() # Zero the values that are close to zero z_states = states.copy() z_states[np.abs(states) <= EPS] = 0 # Compute the non-zero count if nz_state_cnt is None: nz_state_cnt = np.zeros(states.shape[1]) nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) # Update statistics state_sum += np.sum(states, axis=0) state_sum_sq += np.sum(states**2, axis=0) z_state_sum += np.sum(z_states, axis=0) z_state_sum_sq += np.sum(z_states**2, axis=0) state_cnt += states.shape[0] if state_max is None: state_max = np.max(states, axis=0) state_min = np.min(states, axis=0) else: state_max = np.maximum(state_max, np.max(states, axis=0)) state_min = np.minimum(state_min, np.min(states, axis=0)) # Add one to avoid division by zero nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) result = { "dataset_name": name_dataset_pair[0], "state_mean": (state_sum / state_cnt).tolist(), "state_std": np.sqrt( np.maximum( (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt), np.zeros_like(state_sum_sq) ) ).tolist(), "state_min": state_min.tolist(), "state_max": state_max.tolist(), } return result def process_hdf5_dataset(vla_dataset): EPS = 1e-8 episode_cnt = 0 state_sum = 0 state_sum_sq = 0 z_state_sum = 0 z_state_sum_sq = 0 state_cnt = 0 nz_state_cnt = None state_max = None state_min = None for i in tqdm(range(len(vla_dataset))): episode = vla_dataset.get_item(i, state_only=True) episode_cnt += 1 states = episode['state'] # Zero the values that are close to zero z_states = states.copy() z_states[np.abs(states) <= EPS] = 0 # Compute the non-zero count if nz_state_cnt is None: nz_state_cnt = np.zeros(states.shape[1]) nz_state_cnt += np.sum(np.abs(states) > EPS, axis=0) # Update statistics state_sum += np.sum(states, axis=0) state_sum_sq += np.sum(states**2, axis=0) z_state_sum += np.sum(z_states, axis=0) z_state_sum_sq += np.sum(z_states**2, axis=0) state_cnt += states.shape[0] if state_max is None: state_max = np.max(states, axis=0) state_min = np.min(states, axis=0) else: state_max = np.maximum(state_max, np.max(states, axis=0)) state_min = np.minimum(state_min, np.min(states, axis=0)) # Add one to avoid division by zero nz_state_cnt = np.maximum(nz_state_cnt, np.ones_like(nz_state_cnt)) result = { "dataset_name": vla_dataset.get_dataset_name(), 
"state_mean": (state_sum / state_cnt).tolist(), "state_std": np.sqrt( np.maximum( (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt), np.zeros_like(state_sum_sq) ) ).tolist(), "state_min": state_min.tolist(), "state_max": state_max.tolist(), } return result if __name__ == "__main__": parser = argparse.ArgumentParser() # Multiprocessing currently with bugs # parser.add_argument('--n_workers', type=int, default=1, # help="Number of parallel workers.") parser.add_argument('--dataset_type', type=str, default="pretrain", help="Whether to load the pretrain dataset or finetune dataset.") parser.add_argument('--save_path', type=str, default="configs/dataset_stat.json", help="JSON file path to save the dataset statistics.") parser.add_argument('--skip_exist', action='store_true', help="Whether to skip the existing dataset statistics.") parser.add_argument('--hdf5_dataset', action='store_true', help="Whether to load the dataset from the HDF5 files.") args = parser.parse_args() if args.hdf5_dataset: vla_dataset = HDF5VLADataset() dataset_name = vla_dataset.get_dataset_name() try: with open(args.save_path, 'r') as f: results = json.load(f) except FileNotFoundError: results = {} if args.skip_exist and dataset_name in results: print(f"Skipping existed {dataset_name} dataset statistics") else: print(f"Processing {dataset_name} dataset") result = process_hdf5_dataset(vla_dataset) results[result["dataset_name"]] = result with open(args.save_path, 'w') as f: json.dump(results, f, indent=4) print("All datasets have been processed.") os._exit(0) vla_dataset = VLADataset( seed=0, dataset_type=args.dataset_type, repeat=False) name_dataset_pairs = vla_dataset.name2dataset.items() # num_workers = args.n_workers for name_dataset_pair in tqdm(name_dataset_pairs): try: with open(args.save_path, 'r') as f: results = json.load(f) except FileNotFoundError: results = {} if args.skip_exist and name_dataset_pair[0] in results: print(f"Skipping existed {name_dataset_pair[0]} dataset statistics") continue print(f"Processing {name_dataset_pair[0]} dataset") result = process_dataset(name_dataset_pair) results[result["dataset_name"]] = result # Save the results in the json file after each dataset (for resume) with open(args.save_path, 'w') as f: json.dump(results, f, indent=4) print("All datasets have been processed.") # with Manager() as manager: # # Create shared dictionary and lock through the manager, accessible by all processes # progress = manager.dict(processed=0, results={}) # progress_lock = manager.Lock() # # Callback function to update progress # def update_progress(result): # with progress_lock: # progress['processed'] += 1 # print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed") # # Append the result to the shared dictionary # progress['results'][result["dataset_name"]] = result # with Pool(num_workers) as p: # for name_dataset_pair in name_dataset_pairs: # p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress) # # Close the pool and wait for the work to finish # p.close() # p.join() # # Save the results in the json file # with open(args.save_path, 'w') as f: # json.dump(progress['results'], f, indent=4)