|
from tqdm import tqdm |
|
import numpy as np |
|
import argparse |
|
import torch |
|
import lmdb |
|
import glob |
|
import os |
|
|
|
from utils.lmdb import store_arrays_to_lmdb, process_data_dict |
|
|
|
|
|
def main():
    """Aggregate all ODE pairs inside a folder into an LMDB dataset.

    Every ``*.pt`` file directly under ``--data_path`` is loaded with
    ``torch.load``, passed through ``process_data_dict`` together with a
    shared ``seen_prompts`` set, and written to the LMDB environment at
    ``--lmdb_path`` via ``store_arrays_to_lmdb``. Each pt file should contain
    a (key, value) pair representing a video's ODE trajectories.

    After all files are written, one ``"<key>_shape"`` record is stored per
    key of the last processed dict: the space-separated dimensions of that
    value, with the first dimension replaced by the total number of entries
    written.

    Command-line arguments:
        --data_path: path to the folder holding the ode pair ``.pt`` files.
        --lmdb_path: path of the LMDB environment to write.

    Raises:
        FileNotFoundError: if ``--data_path`` contains no ``*.pt`` file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str,
                        required=True, help="path to ode pairs")
    parser.add_argument("--lmdb_path", type=str,
                        required=True, help="path to lmdb")
    args = parser.parse_args()

    # Sorted so that the entry order in the database is reproducible.
    all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))
    if not all_files:
        # Fail before creating an empty database; without this guard the
        # shape-writing section below would hit an unbound ``data_dict``.
        raise FileNotFoundError(
            f"no .pt files found in {args.data_path!r}"
        )

    # Upper bound handed to LMDB as ``map_size`` (doubled below); adjust to
    # the expected dataset size.
    total_array_size = 5_000_000_000_000

    env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)
    try:
        counter = 0  # number of entries written so far / next start index
        seen_prompts = set()  # shared across files by process_data_dict
        data_dict = {}

        for file in tqdm(all_files):
            # NOTE(review): torch.load unpickles its input; only run this on
            # trusted .pt files.
            data_dict = torch.load(file)
            data_dict = process_data_dict(data_dict, seen_prompts)

            store_arrays_to_lmdb(env, data_dict, start_index=counter)
            counter += len(data_dict['prompts'])

        # Record the overall shape of every stored array. The per-entry
        # dimensions are taken from the last processed file -- presumably all
        # files share them; TODO confirm against process_data_dict.
        with env.begin(write=True) as txn:
            for key, val in data_dict.items():
                print(key, val)
                array_shape = np.array(val.shape)
                array_shape[0] = counter  # leading dim = total entry count

                shape_key = f"{key}_shape".encode()
                shape_str = " ".join(map(str, array_shape))
                txn.put(shape_key, shape_str.encode())
    finally:
        # Always release the environment handle, even when a file fails.
        env.close()
|
|
|
|
|
# Run the aggregation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
|
|