from .utils.transforms import * from .base.batched_sampler import BatchedRandomSampler # noqa from .arkitscenes import ARKitScenes_Multi # noqa from .arkitscenes_highres import ARKitScenesHighRes_Multi from .bedlam import BEDLAM_Multi from .blendedmvs import BlendedMVS_Multi # noqa from .co3d import Co3d_Multi # noqa from .cop3d import Cop3D_Multi from .dl3dv import DL3DV_Multi from .dynamic_replica import DynamicReplica from .eden import EDEN_Multi from .hypersim import HyperSim_Multi from .hoi4d import HOI4D_Multi from .irs import IRS from .mapfree import MapFree_Multi from .megadepth import MegaDepth_Multi # noqa from .mp3d import MP3D_Multi from .mvimgnet import MVImgNet_Multi from .mvs_synth import MVS_Synth_Multi from .omniobject3d import OmniObject3D_Multi from .pointodyssey import PointOdyssey_Multi from .realestate10k import RE10K_Multi from .scannet import ScanNet_Multi from .scannetpp import ScanNetpp_Multi # noqa from .smartportraits import SmartPortraits_Multi from .spring import Spring from .synscapes import SynScapes from .tartanair import TartanAir_Multi from .threedkb import ThreeDKenBurns from .uasol import UASOL_Multi from .urbansyn import UrbanSyn from .unreal4k import UnReal4K_Multi from .vkitti2 import VirtualKITTI2_Multi # noqa from .waymo import Waymo_Multi # noqa from .wildrgbd import WildRGBD_Multi # noqa from accelerate import Accelerator def get_data_loader( dataset, batch_size, num_workers=8, shuffle=True, drop_last=True, pin_mem=True, accelerator: Accelerator = None, fixed_length=False, ): import torch # pytorch dataset if isinstance(dataset, str): dataset = eval(dataset) try: sampler = dataset.make_sampler( batch_size, shuffle=shuffle, drop_last=drop_last, world_size=accelerator.num_processes, fixed_length=fixed_length, ) shuffle = False data_loader = torch.utils.data.DataLoader( dataset, batch_sampler=sampler, num_workers=num_workers, pin_memory=pin_mem, ) except (AttributeError, NotImplementedError): sampler = None data_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_mem, drop_last=drop_last, ) return data_loader