import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pathlib import Path
import numpy as np
from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset


class TyphoonDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping the Digital Typhoon benchmark dataset."""

    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels="grade",
        split_by="sequence",
        load_data=False,
        dataset_split=(0.8, 0.1, 0.1),
        standardize_range=(150, 350),
        downsample_size=(224, 224),
        corruption_ceiling_pct=100,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers

        # Expected layout under dataroot: image/, track/, metadata.json
        data_path = Path(dataroot)
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "track") + "/"
        self.metadata_path = str(data_path / "metadata.json")

        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels
        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size
        self.corruption_ceiling_pct = corruption_ceiling_pct

    def setup(self, stage):
        # Load Dataset
        dataset = DigitalTyphoonDataset(
            str(self.images_path),
            str(self.track_path),
            str(self.metadata_path),
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )
        self.train_set, self.val_set, _ = dataset.random_split(
            self.dataset_split, split_by=self.split_by
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def image_filter(self, image):
        # Keep grades 3-5, drop interpolated frames, exclude the 2023 season,
        # and restrict to 100-180 degrees east longitude.
        return (
            (image.grade() < 6)
            and (image.grade() > 2)
            and (image.interpolated() == False)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )  # and (image.mask_1_percent() < self.corruption_ceiling_pct))

    def transform_func(self, image_ray):
        # Clip to the standardization range and rescale to [0, 1].
        image_ray = np.clip(
            image_ray, self.standardize_range[0], self.standardize_range[1]
        )
        image_ray = (image_ray - self.standardize_range[0]) / (
            self.standardize_range[1] - self.standardize_range[0]
        )

        # Bilinearly resize from the native 512 x 512 resolution if needed.
        if self.downsample_size != (512, 512):
            image_ray = torch.Tensor(image_ray)
            image_ray = torch.reshape(
                image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
            )
            image_ray = nn.functional.interpolate(
                image_ray,
                size=self.downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            image_ray = torch.reshape(
                image_ray, [image_ray.size()[2], image_ray.size()[3]]
            )
            image_ray = image_ray.numpy()

        return image_ray
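

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original benchmark script. The
    # dataroot below is a placeholder; it is assumed to point at a local copy
    # of the Digital Typhoon data containing the "image/" and "track/"
    # directories alongside "metadata.json".
    data_module = TyphoonDataModule(
        dataroot="/path/to/dataset",
        batch_size=16,
        num_workers=4,
    )
    data_module.setup(stage="fit")
    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()
    print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")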