"""Build the Digital Typhoon dataset and split its sequence ids by era.

Reads every tunable from ``config``, constructs a ``DigitalTyphoonDataset``
over the infrared imagery, buckets typhoon years into three eras
(pre-2005 / 2005-2014 / 2015+), and shuffles each bucket's sequence ids
for an 80/20 train/validation split.

NOTE(review): the original file's newlines were collapsed; formatting here
is reconstructed. The final split loop was truncated in the source and is
re-created from the script's own "splitting 80/20" comment — verify against
the original file.
"""

import os
import random
from pathlib import Path

import numpy as np
import torch
from torch import nn

import config
from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset

# --- configuration ---------------------------------------------------------
dataroot = config.DATA_DIR
batch_size = config.BATCH_SIZE
num_workers = config.NUM_WORKERS
split_by = config.SPLIT_BY
load_data = config.LOAD_DATA
dataset_split = config.DATASET_SPLIT
standardize_range = config.STANDARDIZE_RANGE
downsample_size = config.DOWNSAMPLE_SIZE
type_save = config.TYPE_SAVE

# Dataset layout: <dataroot>/image/, <dataroot>/track/, <dataroot>/metadata.json.
# The trailing "/" is kept because DigitalTyphoonDataset expects directory
# strings, not Path objects.
data_path = Path(dataroot)
images_path = str(data_path / "image") + "/"
track_path = str(data_path / "track") + "/"
metadata_path = str(data_path / "metadata.json")


def image_filter(image):
    """Predicate selecting usable typhoon images.

    Keeps images with grade below 7, excludes year 2023, and restricts
    longitude to the 100.0-180.0 degree band.
    """
    return (
        (image.grade() < 7)
        and (image.year() != 2023)
        and (100.0 <= image.long() <= 180.0)
    )
    # and (image.mask_1_percent() < self.corruption_ceiling_pct))


def transform_func(image_ray):
    """Clip, rescale to [0, 1], and optionally downsample one image array.

    Values are clipped to ``standardize_range`` then min-max normalized.
    When ``downsample_size`` differs from the native (512, 512), the array
    is bilinearly resized via torch and converted back to numpy.

    Parameters
    ----------
    image_ray : np.ndarray
        2-D image array (assumed 512x512 at native resolution — the
        (512, 512) comparison below relies on that; confirm upstream).

    Returns
    -------
    np.ndarray
        Normalized (and possibly resized) 2-D array.
    """
    lo, hi = standardize_range[0], standardize_range[1]
    image_ray = np.clip(image_ray, lo, hi)
    image_ray = (image_ray - lo) / (hi - lo)
    if downsample_size != (512, 512):
        # interpolate() wants a 4-D (N, C, H, W) tensor; wrap, resize, unwrap.
        tensor = torch.Tensor(image_ray)
        tensor = torch.reshape(tensor, [1, 1, tensor.size()[0], tensor.size()[1]])
        tensor = nn.functional.interpolate(
            tensor,
            size=downsample_size,
            mode="bilinear",
            align_corners=False,
        )
        tensor = torch.reshape(tensor, [tensor.size()[2], tensor.size()[3]])
        image_ray = tensor.numpy()
    return image_ray


dataset = DigitalTyphoonDataset(
    str(images_path),
    str(track_path),
    str(metadata_path),
    "pressure",
    # NOTE(review): config.LOAD_DATA is read into `load_data` above but
    # 'all_data' is hard-coded here — probably should be `load_data`; kept
    # as-is to preserve behavior, confirm intent.
    load_data_into_memory='all_data',
    filter_func=image_filter,
    transform_func=transform_func,
    spectrum="Infrared",
    verbose=False,
)

years = dataset.get_years()

# Splitting years into 3 era buckets.
old = []
recent = []
now = []
for year in years:
    if year < 2005:
        old.append(year)
    elif year < 2015:
        recent.append(year)
    else:
        now.append(year)

# Collect the sequence ids belonging to each era.
old_data = [sid for year in old for sid in dataset.get_seq_ids_from_year(year)]
recent_data = [sid for year in recent for sid in dataset.get_seq_ids_from_year(year)]
now_data = [sid for year in now for sid in dataset.get_seq_ids_from_year(year)]

old_train, old_val = [], []
recent_train, recent_val = [], []
now_train, now_val = [], []

# Shuffling and splitting 80/20.
# NOTE(review): no random seed is set, so the split is non-reproducible
# across runs — consider random.seed(...) if determinism matters.
random.shuffle(old_data)
random.shuffle(now_data)
random.shuffle(recent_data)

l = len(old_data)
# NOTE(review): the source was truncated here ("for i in range(l): if i ...");
# the 80/20 assignment below is reconstructed from the comment above — verify.
for i in range(l):
    if i < 0.8 * l:
        old_train.append(old_data[i])
    else:
        old_val.append(old_data[i])