import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from pathlib import Path
import numpy as np

from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset


class TyphoonDataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels="grade",
        split_by="sequence",
        load_data=False,
        dataset_split=(0.8, 0.1, 0.1),
        standardize_range=(150, 350),
        downsample_size=(224, 224),
        corruption_ceiling_pct=100,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        # The dataset root is expected to contain image/, track/, and
        # metadata.json.
        data_path = Path(dataroot)
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "track") + "/"
        self.metadata_path = str(data_path / "metadata.json")
        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels

        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size

        self.corruption_ceiling_pct = corruption_ceiling_pct

    def setup(self, stage):
        dataset = DigitalTyphoonDataset(
            str(self.images_path),
            str(self.track_path),
            str(self.metadata_path),
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )

        # Split into train/validation/test subsets; the test subset is
        # not used by this module.
        self.train_set, self.val_set, _ = dataset.random_split(
            self.dataset_split, split_by=self.split_by
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def image_filter(self, image):
        # Keep non-interpolated frames with grade strictly between 2 and 6,
        # excluding the 2023 season and longitudes outside 100-180 degrees.
        return (
            (image.grade() < 6)
            and (image.grade() > 2)
            and (not image.interpolated())
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )

    def transform_func(self, image_ray):
        # Clip pixel values to the standardization range, then rescale
        # them to [0, 1].
        image_ray = np.clip(
            image_ray, self.standardize_range[0], self.standardize_range[1]
        )
        image_ray = (image_ray - self.standardize_range[0]) / (
            self.standardize_range[1] - self.standardize_range[0]
        )

        # Downsample with bilinear interpolation unless the target size
        # already matches the native 512x512 resolution.
        if self.downsample_size != (512, 512):
            image_ray = torch.Tensor(image_ray)
            image_ray = torch.reshape(
                image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
            )
            image_ray = nn.functional.interpolate(
                image_ray,
                size=self.downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            image_ray = torch.reshape(
                image_ray, [image_ray.size()[2], image_ray.size()[3]]
            )
            image_ray = image_ray.numpy()

        return image_ray
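
# Usage sketch (illustrative, not part of the module): wiring the DataModule
# into a Lightning Trainer. The dataset root, the LightningModule `model`,
# and the hyperparameters below are placeholders / assumptions.
#
#   data_module = TyphoonDataModule(
#       dataroot="/path/to/digital-typhoon/",
#       batch_size=16,
#       num_workers=8,
#       labels="grade",
#       split_by="sequence",
#   )
#   trainer = pl.Trainer(max_epochs=10, accelerator="auto", devices=1)
#   trainer.fit(model, datamodule=data_module)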