Commit 02cdcbc · 1 Parent(s): f3bf6d6
Migrate benchmark from https://github.com/kitamoto-lab/benchmarks/

Files changed:
- Dockerfile +19 -0
- FrameDatamodule.py +110 -0
- README.md +38 -0
- config.py +28 -0
- createdataset.py +178 -0
- lightning_resnetReg.py +149 -0
- loading.py +43 -0
- split_testing.py +168 -0
- train_split.py +138 -0
Dockerfile
ADDED
@@ -0,0 +1,19 @@
FROM ubuntu

WORKDIR /app

RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y git && \
    apt-get install -y libopenmpi-dev && \
    apt-get install -y python3-pip && \
    git clone https://github.com/kitamoto-lab/pyphoon2.git && \
    cd pyphoon2 && \
    pip3 install . && \
    pip3 install tqdm && \
    pip3 install scikit-learn && \
    pip3 install matplotlib && \
    pip3 install seaborn && \
    pip3 install lightning && \
    pip3 install tensorboardX
FrameDatamodule.py
ADDED
@@ -0,0 +1,110 @@
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader

from pathlib import Path
import numpy as np

from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset


class TyphoonDataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels='grade',
        split_by="sequence",
        load_data=False,
        dataset_split=(0.8, 0.1, 0.1),
        standardize_range=(150, 350),
        downsample_size=(224, 224),
        corruption_ceiling_pct=100,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        data_path = Path(dataroot)
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "track") + "/"
        self.metadata_path = str(data_path / "metadata.json")
        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels

        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size

        self.corruption_ceiling_pct = corruption_ceiling_pct

    def setup(self, stage):
        # Load Dataset
        dataset = DigitalTyphoonDataset(
            str(self.images_path),
            str(self.track_path),
            str(self.metadata_path),
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )

        self.train_set, self.val_set, _ = dataset.random_split(
            self.dataset_split, split_by=self.split_by
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
        )

    def image_filter(self, image):
        return (
            (image.grade() < 6)
            and (image.grade() > 2)
            and (image.interpolated() == False)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )  # and (image.mask_1_percent() < self.corruption_ceiling_pct))

    def transform_func(self, image_ray):
        image_ray = np.clip(
            image_ray, self.standardize_range[0], self.standardize_range[1]
        )
        image_ray = (image_ray - self.standardize_range[0]) / (
            self.standardize_range[1] - self.standardize_range[0]
        )
        if self.downsample_size != (512, 512):
            image_ray = torch.Tensor(image_ray)
            image_ray = torch.reshape(
                image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
            )
            image_ray = nn.functional.interpolate(
                image_ray,
                size=self.downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            image_ray = torch.reshape(
                image_ray, [image_ray.size()[2], image_ray.size()[3]]
            )
            image_ray = image_ray.numpy()
        return image_ray
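TyphoonDataModule is not referenced by the training scripts in this commit (they build their DataLoaders through loading.py instead), so for orientation here is a minimal, hypothetical sketch of how it could be wired into a Lightning Trainer. The hyperparameters and model come from config.py and lightning_resnetReg.py; the wiring itself is an illustration, not part of the commit.

```
# Hypothetical usage sketch (not part of this commit): plugging TyphoonDataModule
# into a pytorch_lightning Trainer with the model from lightning_resnetReg.py.
import pytorch_lightning as pl

import config
from FrameDatamodule import TyphoonDataModule
from lightning_resnetReg import LightningResnetReg

datamodule = TyphoonDataModule(
    dataroot=config.DATA_DIR,            # e.g. /app/datasets/wnp/
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    labels=config.LABELS,
    split_by=config.SPLIT_BY,
    standardize_range=config.STANDARDIZE_RANGE,
    downsample_size=config.DOWNSAMPLE_SIZE,
)

model = LightningResnetReg(
    learning_rate=config.LEARNING_RATE,
    weights=config.WEIGHTS,
    num_classes=config.NUM_CLASSES,
)

trainer = pl.Trainer(
    accelerator=config.ACCELERATOR,
    devices=config.DEVICE,
    max_epochs=config.MAX_EPOCHS,
)
trainer.fit(model, datamodule=datamodule)
```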
README.md
ADDED
@@ -0,0 +1,38 @@

## Instructions to run

#### Docker
All of the commands below should be run in a Docker container built from the Dockerfile in this repo, with the data and the repo exposed as volumes in the container.

To build:

```docker build -t benchmarks_img .```

To run an interactive shell:

```docker run -it --shm-size=2G --gpus all -v /path/to/neurips2023-benchmarks:/neurips2023-benchmarks -v /path/to/datasets/:/data benchmarks_img```


### Reanalysis Task
Every command should be run in the reanalysis folder. The paths to this folder and to the data should be provided in the config.py file.

#### Create buckets
First, you have to split and save the dataset into 3 buckets according to the type of splitting referred to in the config.py file ('standard' for a standard split into before 2005 / between 2005 and 2015 / after 2015; 'same_size' for the same split but with an equal number of sequences per bucket).
```
python3 createdataset.py
```
This will create a folder (named 'save' or 'save_same') with six .txt files containing the ids of the sequences used for training and testing in each bucket (see the layout sketched just after this README).

#### Train
You can now train for the number of runs (called versions in the logs) and epochs specified in the config.py file.
```
python3 train_split.py
```
A tensorboard log will be created for each run on each bucket in the tb_logs directory.

#### Test
After specifying a list of versions in the config.py file, you'll be able to test the models.
```
python3 split_testing.py
```
The accuracy (RMSE in hPa) will be displayed in the terminal and also written to a log.txt file in the ```reanalysis``` directory.
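For reference, this is the bucket layout that createdataset.py writes with TYPE_SAVE = 'standard' (with 'same_size' the folder is save_same/ instead). The file names come straight from createdataset.py; only the tree rendering is illustrative.

```
reanalysis/
    save/
        old_train.txt       sequence ids, years < 2005, training split
        old_val.txt         sequence ids, years < 2005, test split
        recent_train.txt    sequence ids, 2005-2014, training split
        recent_val.txt      sequence ids, 2005-2014, test split
        now_train.txt       sequence ids, years >= 2015, training split
        now_val.txt         sequence ids, years >= 2015, test split
```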
config.py
ADDED
@@ -0,0 +1,28 @@
import torch

# Training Hyperparameters
LEARNING_RATE = 0.0001
BATCH_SIZE = 16
NUM_WORKERS = 16
MAX_EPOCHS = 101
NB_RUNS = 5
TESTING_VERSION = (0, 1, 2, 3, 4)



# DATASET
WEIGHTS = None
LABELS = 'pressure'
SPLIT_BY = 'sequence'
LOAD_DATA = 'all_data'
DATASET_SPLIT = (0.8, 0.1, 0.1)
STANDARDIZE_RANGE = (170, 350)
DOWNSAMPLE_SIZE = (224, 224)
NUM_CLASSES = 1
TYPE_SAVE = 'standard'  # 'standard' or 'same_size'

# Computation
ACCELERATOR = 'gpu' if torch.cuda.is_available() else 'cpu'
DEVICE = [0]
DATA_DIR = '/app/datasets/wnp/'
LOG_DIR = "/app/pyphoon2/reanalysis/tb_logs"
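Since DATA_DIR and LOG_DIR are absolute container paths, a quick optional sanity check (illustrative, not part of the commit) can confirm that the mounted dataset has the layout every script in this benchmark expects, namely an image/ folder, a track/ folder and a metadata.json file, before launching a long run.

```
# Optional sanity check (illustrative, not part of the commit): verify that
# config.DATA_DIR points at a Digital Typhoon dataset with the expected layout.
from pathlib import Path

import config

data_path = Path(config.DATA_DIR)
for required in ("image", "track", "metadata.json"):
    if not (data_path / required).exists():
        raise FileNotFoundError(f"{data_path / required} is missing; "
                                "check the -v volume mount in the docker run command")
print("Dataset layout looks fine:", data_path)
```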
createdataset.py
ADDED
@@ -0,0 +1,178 @@
import config
import torch
from torch import nn
from pathlib import Path
import numpy as np
from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset
import random
import os

dataroot = config.DATA_DIR
batch_size = config.BATCH_SIZE
num_workers = config.NUM_WORKERS
split_by = config.SPLIT_BY
load_data = config.LOAD_DATA
dataset_split = config.DATASET_SPLIT
standardize_range = config.STANDARDIZE_RANGE
downsample_size = config.DOWNSAMPLE_SIZE
type_save = config.TYPE_SAVE

data_path = Path(dataroot)
images_path = str(data_path / "image") + "/"
track_path = str(data_path / "track") + "/"
metadata_path = str(data_path / "metadata.json")

def image_filter(image):
    return (
        (image.grade() < 7)
        and (image.year() != 2023)
        and (100.0 <= image.long() <= 180.0)
    )  # and (image.mask_1_percent() < self.corruption_ceiling_pct))

def transform_func(image_ray):
    image_ray = np.clip(
        image_ray, standardize_range[0], standardize_range[1]
    )
    image_ray = (image_ray - standardize_range[0]) / (
        standardize_range[1] - standardize_range[0]
    )
    if downsample_size != (512, 512):
        image_ray = torch.Tensor(image_ray)
        image_ray = torch.reshape(
            image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
        )
        image_ray = nn.functional.interpolate(
            image_ray,
            size=downsample_size,
            mode="bilinear",
            align_corners=False,
        )
        image_ray = torch.reshape(
            image_ray, [image_ray.size()[2], image_ray.size()[3]]
        )
        image_ray = image_ray.numpy()
    return image_ray

dataset = DigitalTyphoonDataset(
    str(images_path),
    str(track_path),
    str(metadata_path),
    "pressure",
    load_data_into_memory='all_data',
    filter_func=image_filter,
    transform_func=transform_func,
    spectrum="Infrared",
    verbose=False,
)


years = dataset.get_years()
old = []
recent = []
now = []

# splitting years into 3 buckets
for i in years:
    if i < 2005:
        old.append(i)
    else:
        if i < 2015:
            recent.append(i)
        else:
            now.append(i)


old_data = []
recent_data = []
now_data = []

# getting the sequence ids from the years
for year in old:
    old_data.extend(dataset.get_seq_ids_from_year(year))

for year in recent:
    recent_data.extend(dataset.get_seq_ids_from_year(year))

for year in now:
    now_data.extend(dataset.get_seq_ids_from_year(year))

old_train, old_val = [], []
recent_train, recent_val = [], []
now_train, now_val = [], []

# shuffling and splitting 80/20
random.shuffle(old_data)
random.shuffle(now_data)
random.shuffle(recent_data)

l = len(old_data)
for i in range(l):
    if i < l * 0.8:
        old_train.append(old_data[i])
    else:
        old_val.append(old_data[i])

l = len(recent_data)
for i in range(l):
    if i < l * 0.8:
        recent_train.append(recent_data[i])
    else:
        recent_val.append(recent_data[i])

l = len(now_data)
for i in range(l):
    if i < l * 0.8:
        now_train.append(now_data[i])
    else:
        now_val.append(now_data[i])


# writing the ids to file, depending on the save format
if type_save == "standard":
    if not os.path.exists('./save'):
        os.mkdir('./save')
    with open('save/old_train.txt', 'w+') as file:
        for id in old_train:
            file.write(id + "\n")

    with open('save/old_val.txt', 'w+') as file:
        for id in old_val:
            file.write(id + "\n")

    with open('save/recent_train.txt', 'w+') as file:
        for id in recent_train:
            file.write(id + "\n")

    with open('save/recent_val.txt', 'w+') as file:
        for id in recent_val:
            file.write(id + "\n")

    with open('save/now_train.txt', 'w+') as file:
        for id in now_train:
            file.write(id + "\n")

    with open('save/now_val.txt', 'w+') as file:
        for id in now_val:
            file.write(id + "\n")

if type_save == "same_size":
    if not os.path.exists('./save_same'):
        os.mkdir('./save_same')
    with (
        open('save_same/old_train.txt', 'w+') as train1,
        open('save_same/old_val.txt', 'w+') as test1,
        open('save_same/recent_train.txt', 'w+') as train2,
        open('save_same/recent_val.txt', 'w+') as test2,
        open('save_same/now_train.txt', 'w+') as train3,
        open('save_same/now_val.txt', 'w+') as test3,
    ):
        for i in range(min(len(old_train), len(recent_train), len(now_train))):
            train1.write(old_train[i] + '\n')
            train2.write(recent_train[i] + '\n')
            train3.write(now_train[i] + '\n')
        for i in range(min(len(old_val), len(recent_val), len(now_val))):
            test1.write(old_val[i] + '\n')
            test2.write(recent_val[i] + '\n')
            test3.write(now_val[i] + '\n')

print("Saving Done !")
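loading.py below reads these id files back when building DataLoaders. As a standalone check after running createdataset.py, a minimal sketch (illustrative, not part of the commit) that counts the sequence ids written to each bucket, to confirm the 80/20 split looks right:

```
# Illustrative check (not part of the commit): count the sequence ids written
# by createdataset.py for each bucket.
from pathlib import Path

save_dir = Path("save")  # or Path("save_same") when TYPE_SAVE = 'same_size'
for name in ("old", "recent", "now"):
    train_ids = (save_dir / f"{name}_train.txt").read_text().splitlines()
    val_ids = (save_dir / f"{name}_val.txt").read_text().splitlines()
    print(f"{name}: {len(train_ids)} train sequences, {len(val_ids)} test sequences")
```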
lightning_resnetReg.py
ADDED
@@ -0,0 +1,149 @@
import torch.nn as nn
import torch
import torch.optim as optim
from torchvision.models import resnet18
import pytorch_lightning as pl
from torchmetrics import MeanSquaredError



class LightningResnetReg(pl.LightningModule):
    def __init__(self, learning_rate, weights, num_classes):
        super().__init__()
        self.save_hyperparameters()

        self.model = resnet18(num_classes=1, weights=weights)
        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.model.fc = nn.Linear(in_features=512, out_features=1, bias=True)

        self.learning_rate = learning_rate
        self.loss_fn = nn.MSELoss()
        self.accuracy = MeanSquaredError(squared=False)
        self.compt = 1
        self.predicted_labels = []
        self.truth_labels = []


    def forward(self, images):
        images = torch.Tensor(images).float()
        images = torch.reshape(
            images, [images.size()[0], 1, images.size()[1], images.size()[2]]
        )
        output = self.model(images)
        return output

    def training_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        accuracy = self.accuracy(outputs, labels)
        self.log_dict({
            "train_loss": loss,
            "train_RMSE": accuracy
            },
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        self.log("validation_loss", loss,
                 on_step=False, on_epoch=True, sync_dist=True)
        self.predicted_labels.append(outputs)
        self.truth_labels.append(labels.float())
        return loss

    def test_step(self, batch, batch_idx):
        loss, outputs, labels = self._common_step(batch)
        self.log("test_loss", loss,
                 on_step=False, on_epoch=True, sync_dist=True)
        self.predicted_labels.append(outputs)
        self.truth_labels.append(labels.float())
        return loss

    def _common_step(self, batch):
        images, labels = batch
        labels = labels - 2
        labels = torch.reshape(labels, [labels.size()[0], 1])
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.float())
        return loss, outputs, labels

    def predict_step(self, batch):
        images, labels = batch
        labels = labels - 2
        labels = torch.reshape(labels, [labels.size()[0], 1])
        outputs = self.forward(images)
        preds = outputs
        return preds

    def configure_optimizers(self):
        return optim.SGD(self.parameters(), lr=self.learning_rate)

    def on_validation_epoch_end(self):

        tensorboard = self.logger.experiment
        all_preds = torch.concat(self.predicted_labels)
        all_truths = torch.concat(self.truth_labels)
        all_couple = torch.cat((all_truths, all_preds), dim=1)
        wind_values = torch.unique(all_truths)
        pred_means = []
        pred_std = []
        pred_n = []
        for value in wind_values:
            # find all the couples (truth, pred) where truth == value and compute the mean of all the predictions for this value
            m = torch.mean((all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float()))
            std = torch.std((all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float()))
            n = len(all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float())
            pred_means.append(m)
            pred_std.append(std)
            pred_n.append(n)

        # Log regression line graph every 5 epochs
        if self.current_epoch % 5 == 0:
            for i in range(len(wind_values)):
                tensorboard.add_scalars(f"epoch_{self.current_epoch}", {'pred_mean': pred_means[i], 'truth': wind_values[i]}, wind_values[i])
                tensorboard.add_scalars(f"epoch_{self.current_epoch}_stats", {'pred_std': pred_std[i], 'pred_n': pred_n[i]}, wind_values[i])


        self.log("validation_RMSE", self.accuracy(all_preds, all_truths),
                 on_step=False, on_epoch=True, sync_dist=True)
        self.predicted_labels.clear()  # free memory
        self.truth_labels.clear()

    def on_test_epoch_end(self):
        tensorboard = self.logger.experiment

        all_preds = torch.concat(self.predicted_labels)
        all_truths = torch.concat(self.truth_labels)
        all_couple = torch.cat((all_truths, all_preds), dim=1)
        self.logger.experiment.add_embedding(all_couple, tag="couple_label_pred_ep" + str(self.compt) + ".tsv")
        unique_values = torch.unique(all_truths)
        pred_means = []
        pred_std = []
        pred_n = []
        for value in unique_values:
            # find all the couples (truth, pred) where truth == value and compute the mean of all the predictions for this value
            m = torch.mean((all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float()))
            std = torch.std((all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float()))
            n = len(all_couple[torch.where(all_couple[:, 0] == value)][:, 1].float())
            pred_means.append(m)
            pred_std.append(std)
            pred_n.append(n)

        # Log regression line graph every 5 epochs
        if self.current_epoch % 5 == 0:
            for i in range(len(unique_values)):
                tensorboard.add_scalars(f"test_{self.compt}", {'pred_mean': pred_means[i], 'truth': unique_values[i]}, unique_values[i])
                tensorboard.add_scalars(f"test_{self.compt}_stats", {'pred_std': pred_std[i], 'pred_n': pred_n[i]}, unique_values[i])

        Accuracy = self.accuracy(all_preds, all_truths)
        self.log(f"test_{self.compt}_RMSE", Accuracy,
                 on_step=False, on_epoch=True, sync_dist=True)
        with open("log.txt", "a+") as file:
            file.write(f"test_{self.compt}_RMSE : {Accuracy} \n")
        self.predicted_labels.clear()  # free memory
        self.truth_labels.clear()
        self.compt += 1
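As a quick shape check (illustrative only, not part of the commit), the model can be instantiated with the config hyperparameters and run on a random batch shaped like the preprocessed imagery: single-channel (batch, 224, 224) arrays, as produced by transform_func.

```
# Illustrative smoke test: one forward pass through LightningResnetReg on a
# random batch with the (batch, 224, 224) shape that transform_func produces.
import torch

import config
from lightning_resnetReg import LightningResnetReg

model = LightningResnetReg(
    learning_rate=config.LEARNING_RATE,
    weights=config.WEIGHTS,          # None: train from scratch
    num_classes=config.NUM_CLASSES,  # 1: regression head
)
dummy_batch = torch.rand(config.BATCH_SIZE, 224, 224)  # fake infrared images
with torch.no_grad():
    preds = model(dummy_batch)
print(preds.shape)  # expected: torch.Size([16, 1]) with BATCH_SIZE = 16
```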
loading.py
ADDED
@@ -0,0 +1,43 @@
from torch.utils.data import DataLoader

def load(type, dataset, batch_size, num_workers, type_save='standard'):
    # type selects the bucket: 0 = <2005, 1 = 2005-2014, 2 = >=2015,
    # 3 = buckets 1 and 2 combined. Ids are read from save/ or save_same/
    # depending on type_save.
    train, test = [], []
    if type_save == 'standard':
        file_dir = 'save/'
    if type_save == 'same_size':
        file_dir = 'save_same/'

    if type == 0:
        with open(file_dir + 'old_train.txt', 'r') as file:
            train_id = [line for line in file]
        with open(file_dir + 'old_val.txt', 'r') as file:
            test_id = [line for line in file]
    if type == 1:
        with open(file_dir + 'recent_train.txt', 'r') as file:
            train_id = [line for line in file]
        with open(file_dir + 'recent_val.txt', 'r') as file:
            test_id = [line for line in file]
    if type == 2:
        with open(file_dir + 'now_train.txt', 'r') as file:
            train_id = [line for line in file]
        with open(file_dir + 'now_val.txt', 'r') as file:
            test_id = [line for line in file]
    if type == 3:
        with open(file_dir + 'now_train.txt', 'r') as file:
            train_id1 = [line for line in file]
        with open(file_dir + 'now_val.txt', 'r') as file:
            test_id1 = [line for line in file]
        with open(file_dir + 'recent_train.txt', 'r') as file:
            train_id2 = [line for line in file]
        with open(file_dir + 'recent_val.txt', 'r') as file:
            test_id2 = [line for line in file]
        train_id = train_id1 + train_id2
        test_id = test_id1 + test_id2

    train_id = [x.replace('\n', '') for x in train_id]
    test_id = [x.replace('\n', '') for x in test_id]
    train = DataLoader(dataset.images_from_sequences(train_id), batch_size=batch_size, num_workers=num_workers, shuffle=True)
    test = DataLoader(dataset.images_from_sequences(test_id), batch_size=batch_size, num_workers=num_workers, shuffle=False)


    return train, test
split_testing.py
ADDED
@@ -0,0 +1,168 @@
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from lightning_resnetReg import LightningResnetReg
import config
import loading
import torch
from torch import nn
import os

from pathlib import Path
import numpy as np

from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset

def main():
    logger_old = TensorBoardLogger("tb_logs", name="resnet_test_old_same")
    logger_recent = TensorBoardLogger("tb_logs", name="resnet_test_recent_same")
    logger_now = TensorBoardLogger("tb_logs", name="resnet_test_now_same")

    # Set up data
    data_root = config.DATA_DIR
    batch_size = config.BATCH_SIZE
    num_workers = config.NUM_WORKERS
    standardize_range = config.STANDARDIZE_RANGE
    downsample_size = config.DOWNSAMPLE_SIZE
    type_save = config.TYPE_SAVE
    versions = config.TESTING_VERSION


    data_path = Path(data_root)
    images_path = str(data_path / "image") + "/"
    track_path = str(data_path / "track") + "/"
    metadata_path = str(data_path / "metadata.json")

    def image_filter(image):
        return (
            (image.grade() < 7)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )  # and (image.mask_1_percent() < self.corruption_ceiling_pct))

    def transform_func(image_ray):
        image_ray = np.clip(
            image_ray, standardize_range[0], standardize_range[1]
        )
        image_ray = (image_ray - standardize_range[0]) / (
            standardize_range[1] - standardize_range[0]
        )
        if downsample_size != (512, 512):
            image_ray = torch.Tensor(image_ray)
            image_ray = torch.reshape(
                image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
            )
            image_ray = nn.functional.interpolate(
                image_ray,
                size=downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            image_ray = torch.reshape(
                image_ray, [image_ray.size()[2], image_ray.size()[3]]
            )
            image_ray = image_ray.numpy()
        return image_ray

    dataset = DigitalTyphoonDataset(
        str(images_path),
        str(track_path),
        str(metadata_path),
        "pressure",
        load_data_into_memory='all_data',
        filter_func=image_filter,
        transform_func=transform_func,
        spectrum="Infrared",
        verbose=False,
    )


    _, test_old = loading.load(0, dataset, batch_size, num_workers, type_save)
    _, test_recent = loading.load(1, dataset, batch_size, num_workers, type_save)
    _, test_now = loading.load(2, dataset, batch_size, num_workers, type_save)

    # Test

    trainer_old = pl.Trainer(
        logger=logger_old,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    trainer_recent = pl.Trainer(
        logger=logger_recent,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    trainer_now = pl.Trainer(
        logger=logger_now,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    version_dir_old = 'tb_logs/resnet_train_old'
    version_dir_recent = 'tb_logs/resnet_train_recent'
    version_dir_now = 'tb_logs/resnet_train_now'

    if type_save == 'same_size':
        version_dir_old += '_same'
        version_dir_recent += '_same'
        version_dir_now += '_same'




    with open("log.txt", "a+") as file:
        file.write("\n------------------------------------------------------------ \n")
    for i in versions:

        with open("log.txt", "a+") as file:
            file.write(f"\nVersion : {i} \n")
        version_path = f'/version_{i}/checkpoints/'
        _, _, filename_old = next(os.walk(version_dir_old + version_path))
        _, _, filename_recent = next(os.walk(version_dir_recent + version_path))
        _, _, filename_now = next(os.walk(version_dir_now + version_path))
        model_old = LightningResnetReg.load_from_checkpoint(version_dir_old + version_path + filename_old[0])
        model_recent = LightningResnetReg.load_from_checkpoint(version_dir_recent + version_path + filename_recent[0])
        model_now = LightningResnetReg.load_from_checkpoint(version_dir_now + version_path + filename_now[0])

        print("Testing <2005")
        with open("log.txt", "a+") as file:
            file.write("Testing <2005 \n")
        print(" on <2005 : ")
        trainer_old.test(model_old, test_old)
        print(" on >2005 : ")
        trainer_old.test(model_old, test_recent)
        print(" on >2015 : ")
        trainer_old.test(model_old, test_now)

        print("Testing >2005")
        with open("log.txt", "a+") as file:
            file.write("Testing >2005\n")
        print(" on <2005 : ")
        trainer_recent.test(model_recent, test_old)
        print(" on >2005 : ")
        trainer_recent.test(model_recent, test_recent)
        print(" on >2015 : ")
        trainer_recent.test(model_recent, test_now)

        print("Testing >2015")
        with open("log.txt", "a+") as file:
            file.write("Testing >2015\n")
        print(" on <2005 : ")
        trainer_now.test(model_now, test_old)
        print(" on >2005 : ")
        trainer_now.test(model_now, test_recent)
        print(" on >2015 : ")
        trainer_now.test(model_now, test_now)
        print(f"Run {i} done")


if __name__ == "__main__":
    main()
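split_testing.py locates checkpoints by walking a fixed directory pattern under tb_logs. With TYPE_SAVE = 'standard' it expects the layout sketched below; '_same' is appended to each run directory when TYPE_SAVE = 'same_size'. The paths come from version_dir_* and version_path in the script; only the tree rendering is illustrative.

```
tb_logs/
    resnet_train_old/
        version_0/
            checkpoints/
                <checkpoint>.ckpt
        version_1/
            ...
    resnet_train_recent/
        version_0/
            ...
    resnet_train_now/
        version_0/
            ...
```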
train_split.py
ADDED
@@ -0,0 +1,138 @@
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from lightning_resnetReg import LightningResnetReg
import config
import loading
import torch
from torch import nn

from pathlib import Path
import numpy as np

from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset


def main():
    logger_old = TensorBoardLogger("tb_logs", name="resnet_train_old_same")
    logger_recent = TensorBoardLogger("tb_logs", name="resnet_train_recent_same")
    logger_now = TensorBoardLogger("tb_logs", name="resnet_train_now_same")

    # Set up data
    batch_size = config.BATCH_SIZE
    num_workers = config.NUM_WORKERS
    standardize_range = config.STANDARDIZE_RANGE
    downsample_size = config.DOWNSAMPLE_SIZE
    type_save = config.TYPE_SAVE
    nb_runs = config.NB_RUNS

    data_path = Path("/app/datasets/wnp/")
    images_path = str(data_path / "image") + "/"
    track_path = str(data_path / "track") + "/"
    metadata_path = str(data_path / "metadata.json")

    def image_filter(image):
        return (
            (image.grade() < 7)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )  # and (image.mask_1_percent() < self.corruption_ceiling_pct))

    def transform_func(image_ray):
        image_ray = np.clip(
            image_ray, standardize_range[0], standardize_range[1]
        )
        image_ray = (image_ray - standardize_range[0]) / (
            standardize_range[1] - standardize_range[0]
        )
        if downsample_size != (512, 512):
            image_ray = torch.Tensor(image_ray)
            image_ray = torch.reshape(
                image_ray, [1, 1, image_ray.size()[0], image_ray.size()[1]]
            )
            image_ray = nn.functional.interpolate(
                image_ray,
                size=downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            image_ray = torch.reshape(
                image_ray, [image_ray.size()[2], image_ray.size()[3]]
            )
            image_ray = image_ray.numpy()
        return image_ray

    dataset = DigitalTyphoonDataset(
        str(images_path),
        str(track_path),
        str(metadata_path),
        "pressure",
        load_data_into_memory='all_data',
        filter_func=image_filter,
        transform_func=transform_func,
        spectrum="Infrared",
        verbose=False,
    )

    train_old, test_old = loading.load(0, dataset, batch_size, num_workers, type_save)
    train_recent, test_recent = loading.load(1, dataset, batch_size, num_workers, type_save)
    train_now, test_now = loading.load(2, dataset, batch_size, num_workers, type_save)

    # Train

    model_old = LightningResnetReg(
        learning_rate=config.LEARNING_RATE,
        weights=config.WEIGHTS,
        num_classes=config.NUM_CLASSES,
    )


    model_recent = LightningResnetReg(
        learning_rate=config.LEARNING_RATE,
        weights=config.WEIGHTS,
        num_classes=config.NUM_CLASSES,
    )

    model_now = LightningResnetReg(
        learning_rate=config.LEARNING_RATE,
        weights=config.WEIGHTS,
        num_classes=config.NUM_CLASSES,
    )


    trainer_old = pl.Trainer(
        logger=logger_old,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    trainer_recent = pl.Trainer(
        logger=logger_recent,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    trainer_now = pl.Trainer(
        logger=logger_now,
        accelerator=config.ACCELERATOR,
        devices=config.DEVICE,
        max_epochs=config.MAX_EPOCHS,
        default_root_dir=config.LOG_DIR,
    )

    for i in range(nb_runs):
        print("Training <2005")
        trainer_old.fit(model_old, train_old, test_old)

        print("Training >2005")
        trainer_recent.fit(model_recent, train_recent, test_recent)

        print("Training >2015")
        trainer_now.fit(model_now, train_now, test_now)


if __name__ == "__main__":
    main()