jaredhwang commited on
Commit
02cdcbc
·
1 Parent(s): f3bf6d6

Migrate benchmark from https://github.com/kitamoto-lab/benchmarks/

Browse files
Files changed (9) hide show
  1. Dockerfile +19 -0
  2. FrameDatamodule.py +110 -0
  3. README.md +38 -0
  4. config.py +28 -0
  5. createdataset.py +178 -0
  6. lightning_resnetReg.py +149 -0
  7. loading.py +43 -0
  8. split_testing.py +168 -0
  9. train_split.py +138 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Pinned to an LTS release: a floating `ubuntu` tag makes builds
# irreproducible, and newer Ubuntu images mark the system Python as
# externally managed (PEP 668), which makes the system-wide `pip3 install`
# calls below fail.
FROM ubuntu:22.04

WORKDIR /app

# Prevent package configuration from blocking the build on interactive prompts.
ARG DEBIAN_FRONTEND=noninteractive

# System packages, the pyphoon2 data loader (installed from source) and the
# Python dependencies of the benchmark scripts.
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y git libopenmpi-dev python3-pip && \
    rm -rf /var/lib/apt/lists/* && \
    git clone https://github.com/kitamoto-lab/pyphoon2.git && \
    cd pyphoon2 && \
    pip3 install --no-cache-dir . && \
    pip3 install --no-cache-dir \
        tqdm \
        scikit-learn \
        matplotlib \
        seaborn \
        lightning \
        tensorboardX
19
+
FrameDatamodule.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import pytorch_lightning as pl
4
+ from torch.utils.data import DataLoader
5
+
6
+ from pathlib import Path
7
+ import numpy as np
8
+
9
+ from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset
10
+
11
+
12
class TyphoonDataModule(pl.LightningDataModule):
    """Lightning data module wrapping a ``DigitalTyphoonDataset``.

    The dataset is read from ``<dataroot>/image/``, ``<dataroot>/track/`` and
    ``<dataroot>/metadata.json``, filtered with :meth:`image_filter`,
    transformed with :meth:`transform_func` and randomly split into
    train / validation / test subsets.
    """

    def __init__(
        self,
        dataroot,
        batch_size,
        num_workers,
        labels='grade',
        split_by="sequence",
        load_data=False,
        dataset_split=(0.8, 0.1, 0.1),
        standardize_range=(150, 350),
        downsample_size=(224, 224),
        corruption_ceiling_pct=100,
    ):
        """Store the configuration; no data is touched until :meth:`setup`.

        Args:
            dataroot: Directory containing ``image/``, ``track/`` and
                ``metadata.json``.
            batch_size: Batch size used by every dataloader.
            num_workers: Worker process count used by every dataloader.
            labels: Label name forwarded to ``DigitalTyphoonDataset``.
            split_by: Forwarded to ``dataset.random_split``.
            load_data: Forwarded as ``load_data_into_memory``.
            dataset_split: Fractions forwarded to ``dataset.random_split``.
            standardize_range: ``(low, high)`` bounds used to clip and
                rescale pixel values to ``[0, 1]``.
            downsample_size: Target ``(height, width)``; ``(512, 512)``
                disables resizing.
            corruption_ceiling_pct: Stored but currently not used by
                :meth:`image_filter`.
        """
        super().__init__()

        self.batch_size = batch_size
        self.num_workers = num_workers

        data_path = Path(dataroot)
        self.images_path = str(data_path / "image") + "/"
        self.track_path = str(data_path / "track") + "/"
        self.metadata_path = str(data_path / "metadata.json")
        self.load_data = load_data
        self.split_by = split_by
        self.labels = labels

        self.dataset_split = dataset_split
        self.standardize_range = standardize_range
        self.downsample_size = downsample_size

        self.corruption_ceiling_pct = corruption_ceiling_pct

    def setup(self, stage=None):
        """Build the dataset and create the train / validation / test subsets.

        ``stage`` is accepted for the Lightning hook signature and ignored.
        """
        dataset = DigitalTyphoonDataset(
            str(self.images_path),
            str(self.track_path),
            str(self.metadata_path),
            self.labels,
            load_data_into_memory=self.load_data,
            filter_func=self.image_filter,
            transform_func=self.transform_func,
            spectrum="Infrared",
            verbose=False,
        )

        # Keep the third subset as well so that it can be served by
        # test_dataloader() instead of being silently dropped.
        self.train_set, self.val_set, self.test_set = dataset.random_split(
            self.dataset_split, split_by=self.split_by
        )

    def _make_loader(self, subset, shuffle):
        """Return a ``DataLoader`` over *subset* with the module's settings."""
        return DataLoader(
            subset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) validation dataloader."""
        return self._make_loader(self.val_set, shuffle=False)

    def test_dataloader(self):
        """Return the (unshuffled) dataloader over the held-out test subset."""
        return self._make_loader(self.test_set, shuffle=False)

    def image_filter(self, image):
        """Return True when *image* should be kept in the dataset.

        An image is kept when its grade is strictly between 2 and 6, it is
        not interpolated, it is not from 2023 and its longitude lies in
        ``[100, 180]``.
        """
        return (
            2 < image.grade() < 6
            and not image.interpolated()
            and image.year() != 2023
            and 100.0 <= image.long() <= 180.0
        )

    def transform_func(self, image_ray):
        """Clip, rescale to ``[0, 1]`` and optionally resize one image.

        Args:
            image_ray: 2-D array of pixel values.

        Returns:
            The transformed 2-D array; resized to ``downsample_size`` with
            bilinear interpolation unless that size is ``(512, 512)``.
        """
        low, high = self.standardize_range[0], self.standardize_range[1]
        image_ray = (np.clip(image_ray, low, high) - low) / (high - low)

        if self.downsample_size != (512, 512):
            tensor = torch.Tensor(image_ray)
            # interpolate() expects a (batch, channel, height, width) input.
            tensor = torch.reshape(
                tensor, [1, 1, tensor.size()[0], tensor.size()[1]]
            )
            tensor = nn.functional.interpolate(
                tensor,
                size=self.downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            tensor = torch.reshape(tensor, [tensor.size()[2], tensor.size()[3]])
            image_ray = tensor.numpy()
        return image_ray
README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Instructions to run
3
+
4
+ #### Docker
5
+ All of the below commands should be run in a Docker container built using the Dockerfile in the repo, with the data and repo being exposed as volumes in the container.
6
+
7
+ To build:
8
+
9
+ ```docker build -t benchmarks_img .```
10
+
11
+ To run an interactive shell:
12
+
13
+ ```docker run -it --shm-size=2G --gpus all -v /path/to/neurips2023-benchmarks:/neurips2023-benchmarks -v /path/to/datasets/:/data benchmarks_img```
14
+
15
+
16
+ ### Reanalysis Task
17
+ Every command should be run in the reanalysis folder. The path to this folder and to the data should be provided in the config.py file.
18
+
19
+ #### Create buckets
20
+ First, you have to split and save the dataset into 3 buckets according to the type of splitting referred to in the config.py file ('standard' for standard splitting between before 2005 / between 2005 and 2015 / after 2015, 'same_size' for the same splitting but with an equal number of sequences per bucket).
21
+ ```
22
+ python3 createdataset.py
23
+ ```
24
+ This will create a folder (named 'save' or 'save_same') with 6 .txt files containing the ids of the sequences used for training and testing in each bucket.
25
+
26
+ #### Train
27
+ You can now train for a number of runs (called version in the logs) and epochs specified in the config.py file.
28
+ ```
29
+ python3 train_split.py
30
+ ```
31
+ A tensorboard log will be created for each run with each bucket in the tb_logs folder.
32
+
33
+ #### Test
34
+ After specifying a list of versions in the config.py file, you'll be able to test the model.
35
+ ```
36
+ python3 split_testing.py
37
+ ```
38
+ The accuracy (RMSE in hPa) will be displayed on the terminal but also written in a log.txt file in the directory ```reanalysis```.
config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
# Training hyperparameters
LEARNING_RATE = 0.0001
BATCH_SIZE = 16
NUM_WORKERS = 16
MAX_EPOCHS = 101
# Number of independent training runs performed by train_split.py.
NB_RUNS = 5
# Logger versions whose checkpoints split_testing.py evaluates.
TESTING_VERSION = (0, 1, 2, 3, 4)


# Dataset
WEIGHTS = None
LABELS = 'pressure'
SPLIT_BY = 'sequence'
LOAD_DATA = 'all_data'
DATASET_SPLIT = (0.8, 0.1, 0.1)
# (low, high) bounds used to clip pixel values before rescaling to [0, 1].
STANDARDIZE_RANGE = (170, 350)
DOWNSAMPLE_SIZE = (224, 224)
NUM_CLASSES = 1
TYPE_SAVE = 'standard'  # 'standard' or 'same_size'

# Computation
ACCELERATOR = 'gpu' if torch.cuda.is_available() else 'cpu'
# A list selects GPUs by index. Lightning's CPU accelerator only accepts a
# positive integer for `devices`, so fall back to a single process there.
DEVICE = [0] if ACCELERATOR == 'gpu' else 1
DATA_DIR = '/app/datasets/wnp/'
LOG_DIR = "/app/pyphoon2/reanalysis/tb_logs"
createdataset.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import config
2
+ import torch
3
+ from torch import nn
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset
7
+ import random
8
+ import os
9
+
10
dataroot = config.DATA_DIR
batch_size = config.BATCH_SIZE
num_workers = config.NUM_WORKERS
split_by = config.SPLIT_BY
load_data = config.LOAD_DATA
dataset_split = config.DATASET_SPLIT
standardize_range = config.STANDARDIZE_RANGE
downsample_size = config.DOWNSAMPLE_SIZE
type_save = config.TYPE_SAVE

# Output folder for each supported splitting mode.
SAVE_DIRS = {"standard": "save", "same_size": "save_same"}

# Fail before the dataset is loaded rather than finishing with a misleading
# "Saving Done !" message and no files written.
if type_save not in SAVE_DIRS:
    raise ValueError(
        f"Unknown TYPE_SAVE {type_save!r}; expected one of {sorted(SAVE_DIRS)}"
    )

data_path = Path(dataroot)
images_path = str(data_path / "image") + "/"
track_path = str(data_path / "track") + "/"
metadata_path = str(data_path / "metadata.json")


def image_filter(image):
    """Keep images with grade < 7, not from 2023, with longitude in [100, 180]."""
    return (
        (image.grade() < 7)
        and (image.year() != 2023)
        and (100.0 <= image.long() <= 180.0)
    )


def transform_func(image_ray):
    """Clip and rescale one image to [0, 1], resizing unless the size is (512, 512)."""
    low, high = standardize_range[0], standardize_range[1]
    image_ray = (np.clip(image_ray, low, high) - low) / (high - low)
    if downsample_size != (512, 512):
        tensor = torch.Tensor(image_ray)
        # interpolate() expects a (batch, channel, height, width) input.
        tensor = torch.reshape(tensor, [1, 1, tensor.size()[0], tensor.size()[1]])
        tensor = nn.functional.interpolate(
            tensor,
            size=downsample_size,
            mode="bilinear",
            align_corners=False,
        )
        tensor = torch.reshape(tensor, [tensor.size()[2], tensor.size()[3]])
        image_ray = tensor.numpy()
    return image_ray


def bucket_of(year):
    """Return the bucket name for *year*: before 2005, 2005-2014, or 2015 onwards."""
    if year < 2005:
        return "old"
    if year < 2015:
        return "recent"
    return "now"


def split_train_val(seq_ids, train_fraction=0.8):
    """Split *seq_ids* in order: indices below len * train_fraction go to train."""
    limit = len(seq_ids) * train_fraction
    train_ids = [seq for index, seq in enumerate(seq_ids) if index < limit]
    val_ids = [seq for index, seq in enumerate(seq_ids) if index >= limit]
    return train_ids, val_ids


def write_ids(path, seq_ids):
    """Write one sequence id per line to *path*, replacing any existing file."""
    with open(path, "w") as file:
        file.writelines(f"{seq}\n" for seq in seq_ids)


dataset = DigitalTyphoonDataset(
    str(images_path),
    str(track_path),
    str(metadata_path),
    "pressure",
    load_data_into_memory='all_data',
    filter_func=image_filter,
    transform_func=transform_func,
    spectrum="Infrared",
    verbose=False,
)

# Collect the sequence ids of every year into its bucket.
bucket_ids = {"old": [], "recent": [], "now": []}
for year in dataset.get_years():
    bucket_ids[bucket_of(year)].extend(dataset.get_seq_ids_from_year(year))

# Shuffle, then split 80/20. The shuffle is unseeded, so every execution
# produces a different partition.
for name in ("old", "now", "recent"):
    random.shuffle(bucket_ids[name])

splits = {name: split_train_val(ids) for name, ids in bucket_ids.items()}

if type_save == "same_size":
    # Truncate every bucket to the size of the smallest one.
    train_size = min(len(train_ids) for train_ids, _ in splits.values())
    val_size = min(len(val_ids) for _, val_ids in splits.values())
    splits = {
        name: (train_ids[:train_size], val_ids[:val_size])
        for name, (train_ids, val_ids) in splits.items()
    }

output_dir = Path(SAVE_DIRS[type_save])
os.makedirs(output_dir, exist_ok=True)
for name, (train_ids, val_ids) in splits.items():
    write_ids(output_dir / f"{name}_train.txt", train_ids)
    write_ids(output_dir / f"{name}_val.txt", val_ids)

print("Saving Done !")
178
+
lightning_resnetReg.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.optim as optim
4
+ from torchvision.models import resnet18
5
+ import pytorch_lightning as pl
6
+ from torchmetrics import MeanSquaredError
7
+
8
+
9
+
10
class LightningResnetReg(pl.LightningModule):
    """ResNet-18 regressor for single-channel images with one scalar output.

    Training minimises the MSE loss; RMSE is reported for the training,
    validation and test phases. Validation and test predictions are buffered
    per step and analysed at the end of each epoch.
    """

    def __init__(self, learning_rate, weights, num_classes):
        """Build the network.

        Args:
            learning_rate: Learning rate of the SGD optimizer.
            weights: Forwarded to ``torchvision.models.resnet18``.
            num_classes: Recorded as a hyperparameter only; the head always
                has a single output, as in the original implementation.
        """
        super().__init__()
        self.save_hyperparameters()

        # Do not pass num_classes here: torchvision rejects pretrained
        # weights combined with a class count different from the weights'
        # own. The head is replaced just below in any case.
        self.model = resnet18(weights=weights)
        # Single input channel instead of RGB.
        self.model.conv1 = nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.model.fc = nn.Linear(in_features=512, out_features=1, bias=True)

        self.learning_rate = learning_rate
        self.loss_fn = nn.MSELoss()
        # squared=False turns the metric into a root mean squared error.
        self.accuracy = MeanSquaredError(squared=False)
        # Counter of completed test epochs; used to name test logs.
        self.compt = 1
        self.predicted_labels = []
        self.truth_labels = []

    def forward(self, images):
        """Run the network on a batch of 2-D images (a channel axis is inserted)."""
        images = torch.as_tensor(images).float()
        images = images.unsqueeze(1)
        return self.model(images)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log loss and RMSE."""
        loss, outputs, labels = self._common_step(batch)
        accuracy = self.accuracy(outputs, labels)
        self.log_dict(
            {
                "train_loss": loss,
                "train_RMSE": accuracy,
            },
            on_step=False,
            on_epoch=True,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and buffer predictions for the epoch end."""
        return self._buffered_eval_step(batch, "validation_loss")

    def test_step(self, batch, batch_idx):
        """Log the test loss and buffer predictions for the epoch end."""
        return self._buffered_eval_step(batch, "test_loss")

    def _buffered_eval_step(self, batch, loss_name):
        """Shared validation/test step: log the loss, store outputs and labels."""
        loss, outputs, labels = self._common_step(batch)
        self.log(loss_name, loss,
                 on_step=False, on_epoch=True, sync_dist=True)
        self.predicted_labels.append(outputs)
        self.truth_labels.append(labels.float())
        return loss

    def _common_step(self, batch):
        """Return ``(loss, outputs, labels)`` for one ``(images, labels)`` batch."""
        images, labels = batch
        # NOTE(review): a constant 2 is subtracted from every label, so the
        # network is trained on (and predicts) offset targets. A constant
        # offset does not change the RMSE; confirm it is intended.
        labels = labels - 2
        labels = torch.reshape(labels, [labels.size()[0], 1])
        outputs = self.forward(images)
        loss = self.loss_fn(outputs, labels.float())
        return loss, outputs, labels

    def predict_step(self, batch, batch_idx=None, dataloader_idx=0):
        """Return the raw network outputs for one ``(images, labels)`` batch.

        ``batch_idx`` and ``dataloader_idx`` are accepted (and ignored) so the
        hook can be called with the positional arguments Lightning passes.
        """
        images, _ = batch
        return self.forward(images)

    def configure_optimizers(self):
        """Use plain SGD with the configured learning rate."""
        return optim.SGD(self.parameters(), lr=self.learning_rate)

    @staticmethod
    def _stats_per_truth(all_truths, all_preds):
        """Group predictions by distinct truth value.

        Returns ``(values, means, stds, counts)`` where, for every distinct
        value in *all_truths*, the lists hold the mean, standard deviation
        and number of the predictions whose truth equals that value.
        """
        truths = all_truths[:, 0]
        preds = all_preds[:, 0].float()
        values = torch.unique(all_truths)
        means, stds, counts = [], [], []
        for value in values:
            selected = preds[truths == value]
            means.append(torch.mean(selected))
            stds.append(torch.std(selected))
            counts.append(len(selected))
        return values, means, stds, counts

    def on_validation_epoch_end(self):
        """Log per-value prediction statistics (every 5 epochs) and the RMSE."""
        if not self.predicted_labels:
            # Nothing was buffered; concatenating empty lists would raise.
            return

        tensorboard = self.logger.experiment
        all_preds = torch.concat(self.predicted_labels)
        all_truths = torch.concat(self.truth_labels)
        wind_values, pred_means, pred_std, pred_n = self._stats_per_truth(
            all_truths, all_preds
        )

        # Log regression line graph every 5 epochs
        if self.current_epoch % 5 == 0:
            for value, mean, std, count in zip(
                wind_values, pred_means, pred_std, pred_n
            ):
                tensorboard.add_scalars(
                    f"epoch_{self.current_epoch}",
                    {'pred_mean': mean, 'truth': value},
                    value,
                )
                tensorboard.add_scalars(
                    f"epoch_{self.current_epoch}_stats",
                    {'pred_std': std, 'pred_n': count},
                    value,
                )

        self.log("validation_RMSE", self.accuracy(all_preds, all_truths),
                 on_step=False, on_epoch=True, sync_dist=True)
        self.predicted_labels.clear()  # free memory
        self.truth_labels.clear()

    def on_test_epoch_end(self):
        """Log test statistics and RMSE, and append the RMSE to ``log.txt``."""
        if not self.predicted_labels:
            # Nothing was buffered; concatenating empty lists would raise.
            return

        tensorboard = self.logger.experiment

        all_preds = torch.concat(self.predicted_labels)
        all_truths = torch.concat(self.truth_labels)
        all_couple = torch.cat((all_truths, all_preds), dim=1)
        tensorboard.add_embedding(
            all_couple, tag="couple_label_pred_ep" + str(self.compt) + ".tsv"
        )
        unique_values, pred_means, pred_std, pred_n = self._stats_per_truth(
            all_truths, all_preds
        )

        # Log regression line graph every 5 epochs
        if self.current_epoch % 5 == 0:
            for value, mean, std, count in zip(
                unique_values, pred_means, pred_std, pred_n
            ):
                tensorboard.add_scalars(
                    f"test_{self.compt}",
                    {'pred_mean': mean, 'truth': value},
                    value,
                )
                tensorboard.add_scalars(
                    f"test_{self.compt}_stats",
                    {'pred_std': std, 'pred_n': count},
                    value,
                )

        rmse = self.accuracy(all_preds, all_truths)
        self.log(f"test_{self.compt}_RMSE", rmse,
                 on_step=False, on_epoch=True, sync_dist=True)
        with open("log.txt", "a+") as file:
            file.write(f"test_{self.compt}_RMSE : {rmse} \n")
        self.predicted_labels.clear()  # free memory
        self.truth_labels.clear()
        self.compt += 1
loading.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+
3
def load(type, dataset, batch_size, num_workers, type_save='standard'):
    """Build the train and test dataloaders of one time bucket.

    Sequence ids are read from the ``<bucket>_train.txt`` and
    ``<bucket>_val.txt`` files written by createdataset.py.

    Args:
        type: Bucket selector: 0 = old, 1 = recent, 2 = now,
            3 = now followed by recent.
        dataset: Object providing ``images_from_sequences(ids)``.
        batch_size: Batch size of both dataloaders.
        num_workers: Worker process count of both dataloaders.
        type_save: ``'standard'`` (reads ``save/``) or ``'same_size'``
            (reads ``save_same/``).

    Returns:
        ``(train, test)``: a shuffled training ``DataLoader`` and an
        unshuffled test ``DataLoader``.

    Raises:
        ValueError: If *type* or *type_save* is not one of the values above.
    """
    save_dirs = {'standard': 'save/', 'same_size': 'save_same/'}
    buckets_by_type = {
        0: ('old',),
        1: ('recent',),
        2: ('now',),
        3: ('now', 'recent'),
    }

    # Reject unknown selectors explicitly instead of failing later with an
    # unbound local variable.
    if type_save not in save_dirs:
        raise ValueError(
            f"Unknown type_save {type_save!r}; expected one of {sorted(save_dirs)}"
        )
    if type not in buckets_by_type:
        raise ValueError(
            f"Unknown type {type!r}; expected one of {sorted(buckets_by_type)}"
        )

    file_dir = save_dirs[type_save]

    def read_ids(split):
        """Return the ids stored for *split* ('train' or 'val'), newlines removed."""
        ids = []
        for bucket in buckets_by_type[type]:
            with open(file_dir + f'{bucket}_{split}.txt', 'r') as file:
                ids.extend(line.replace('\n', '') for line in file)
        return ids

    train = DataLoader(
        dataset.images_from_sequences(read_ids('train')),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    test = DataLoader(
        dataset.images_from_sequences(read_ids('val')),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )
    return train, test
split_testing.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.loggers import TensorBoardLogger
3
+ from lightning_resnetReg import LightningResnetReg
4
+ import config
5
+ import loading
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+
10
+ from pathlib import Path
11
+ import numpy as np
12
+
13
+ from pyphoon2.DigitalTyphoonDataset import DigitalTyphoonDataset
14
+
15
def _find_checkpoint(checkpoint_dir):
    """Return the path of the first file (sorted by name) in *checkpoint_dir*.

    Raises:
        FileNotFoundError: If the directory is missing or contains no file.
    """
    try:
        filenames = sorted(
            entry for entry in os.listdir(checkpoint_dir)
            if os.path.isfile(os.path.join(checkpoint_dir, entry))
        )
    except FileNotFoundError as err:
        raise FileNotFoundError(
            f"Checkpoint directory not found: {checkpoint_dir}"
        ) from err
    if not filenames:
        raise FileNotFoundError(f"No checkpoint file in {checkpoint_dir}")
    return os.path.join(checkpoint_dir, filenames[0])


def main():
    """Test every bucket's trained model on the test split of every bucket.

    For each version listed in ``config.TESTING_VERSION`` the checkpoint of
    the 'old', 'recent' and 'now' models is loaded from ``tb_logs`` and each
    model is tested on the three test dataloaders. Section headers are
    appended to ``log.txt``.
    """
    # Set up data
    data_root = config.DATA_DIR
    batch_size = config.BATCH_SIZE
    num_workers = config.NUM_WORKERS
    standardize_range = config.STANDARDIZE_RANGE
    downsample_size = config.DOWNSAMPLE_SIZE
    type_save = config.TYPE_SAVE
    versions = config.TESTING_VERSION

    # Bucket name -> label used in console/log output. Dict order is also
    # the order in which models are tested and test sets are visited.
    bucket_labels = {"old": "<2005", "recent": ">2005", "now": ">2015"}
    # Header lines written to log.txt, reproduced exactly.
    log_headers = {
        "old": "Testing <2005 \n",
        "recent": "Testing >2005\n",
        "now": "Testing >2015\n",
    }
    # Runs made on the 'same_size' split use a '_same' suffix for both the
    # training logs that are read and the test logs that are written.
    suffix = '_same' if type_save == 'same_size' else ''

    data_path = Path(data_root)
    images_path = str(data_path / "image") + "/"
    track_path = str(data_path / "track") + "/"
    metadata_path = str(data_path / "metadata.json")

    def image_filter(image):
        """Keep images with grade < 7, not from 2023, longitude in [100, 180]."""
        return (
            (image.grade() < 7)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )

    def transform_func(image_ray):
        """Clip and rescale one image to [0, 1], resizing unless (512, 512)."""
        low, high = standardize_range[0], standardize_range[1]
        image_ray = (np.clip(image_ray, low, high) - low) / (high - low)
        if downsample_size != (512, 512):
            tensor = torch.Tensor(image_ray)
            # interpolate() expects a (batch, channel, height, width) input.
            tensor = torch.reshape(
                tensor, [1, 1, tensor.size()[0], tensor.size()[1]]
            )
            tensor = nn.functional.interpolate(
                tensor,
                size=downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            tensor = torch.reshape(tensor, [tensor.size()[2], tensor.size()[3]])
            image_ray = tensor.numpy()
        return image_ray

    dataset = DigitalTyphoonDataset(
        str(images_path),
        str(track_path),
        str(metadata_path),
        "pressure",
        load_data_into_memory='all_data',
        filter_func=image_filter,
        transform_func=transform_func,
        spectrum="Infrared",
        verbose=False,
    )

    # loading.load() selects the bucket by index: 0 = old, 1 = recent, 2 = now.
    test_loaders = {
        bucket: loading.load(index, dataset, batch_size, num_workers, type_save)[1]
        for index, bucket in enumerate(bucket_labels)
    }

    # Test
    trainers = {
        bucket: pl.Trainer(
            logger=TensorBoardLogger("tb_logs", name=f"resnet_test_{bucket}{suffix}"),
            accelerator=config.ACCELERATOR,
            devices=config.DEVICE,
            max_epochs=config.MAX_EPOCHS,
            default_root_dir=config.LOG_DIR,
        )
        for bucket in bucket_labels
    }

    train_log_dirs = {
        bucket: f'tb_logs/resnet_train_{bucket}{suffix}' for bucket in bucket_labels
    }

    with open("log.txt", "a+") as file:
        file.write("\n------------------------------------------------------------ \n")

    for i in versions:
        with open("log.txt", "a+") as file:
            file.write(f"\nVersion : {i} \n")

        # Load the three models first so that a missing checkpoint is
        # reported before any testing starts for this version.
        models = {
            bucket: LightningResnetReg.load_from_checkpoint(
                _find_checkpoint(f'{train_log_dirs[bucket]}/version_{i}/checkpoints/')
            )
            for bucket in bucket_labels
        }

        for bucket, label in bucket_labels.items():
            print(f"Testing {label}")
            with open("log.txt", "a+") as file:
                file.write(log_headers[bucket])
            for target, target_label in bucket_labels.items():
                print(f" on {target_label} : ")
                trainers[bucket].test(models[bucket], test_loaders[target])
        print(f"Run {i} done")
165
+
166
+
167
+ if __name__ == "__main__":
168
+ main()
train_split.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytorch_lightning as pl
2
+ from pytorch_lightning.loggers import TensorBoardLogger
3
+ from lightning_resnetReg import LightningResnetReg
4
+ import config
5
+ import loading
6
+ import torch
7
+ from torch import nn
8
+
9
+ from pathlib import Path
10
+ import numpy as np
11
+
12
+ from DigitalTyphoonDataloader.DigitalTyphoonDataset import DigitalTyphoonDataset
13
+
14
+
15
def main():
    """Train one ResNet regressor per time bucket, ``config.NB_RUNS`` times.

    Every run builds a fresh logger, model and trainer for each of the
    'old', 'recent' and 'now' buckets, so each run starts from new weights
    and is written to its own logger version under ``tb_logs``.

    NOTE(review): this file imports ``DigitalTyphoonDataset`` from
    ``DigitalTyphoonDataloader`` whereas the sibling scripts import it from
    ``pyphoon2`` - confirm which package name is installed.
    """
    # Set up data
    batch_size = config.BATCH_SIZE
    num_workers = config.NUM_WORKERS
    standardize_range = config.STANDARDIZE_RANGE
    downsample_size = config.DOWNSAMPLE_SIZE
    type_save = config.TYPE_SAVE
    nb_runs = config.NB_RUNS

    # Bucket name -> label used in console output; dict order is the
    # training order within a run.
    bucket_labels = {"old": "<2005", "recent": ">2005", "now": ">2015"}
    # Only runs on the 'same_size' split carry the '_same' suffix, matching
    # the directories split_testing.py reads checkpoints from.
    suffix = '_same' if type_save == 'same_size' else ''

    # Read the data location from the configuration, like the other scripts.
    data_path = Path(config.DATA_DIR)
    images_path = str(data_path / "image") + "/"
    track_path = str(data_path / "track") + "/"
    metadata_path = str(data_path / "metadata.json")

    def image_filter(image):
        """Keep images with grade < 7, not from 2023, longitude in [100, 180]."""
        return (
            (image.grade() < 7)
            and (image.year() != 2023)
            and (100.0 <= image.long() <= 180.0)
        )

    def transform_func(image_ray):
        """Clip and rescale one image to [0, 1], resizing unless (512, 512)."""
        low, high = standardize_range[0], standardize_range[1]
        image_ray = (np.clip(image_ray, low, high) - low) / (high - low)
        if downsample_size != (512, 512):
            tensor = torch.Tensor(image_ray)
            # interpolate() expects a (batch, channel, height, width) input.
            tensor = torch.reshape(
                tensor, [1, 1, tensor.size()[0], tensor.size()[1]]
            )
            tensor = nn.functional.interpolate(
                tensor,
                size=downsample_size,
                mode="bilinear",
                align_corners=False,
            )
            tensor = torch.reshape(tensor, [tensor.size()[2], tensor.size()[3]])
            image_ray = tensor.numpy()
        return image_ray

    dataset = DigitalTyphoonDataset(
        str(images_path),
        str(track_path),
        str(metadata_path),
        "pressure",
        load_data_into_memory='all_data',
        filter_func=image_filter,
        transform_func=transform_func,
        spectrum="Infrared",
        verbose=False,
    )

    # loading.load() selects the bucket by index: 0 = old, 1 = recent, 2 = now.
    loaders = {
        bucket: loading.load(index, dataset, batch_size, num_workers, type_save)
        for index, bucket in enumerate(bucket_labels)
    }

    # Train
    for _ in range(nb_runs):
        for bucket, label in bucket_labels.items():
            # Re-created on every run: a trainer that has already reached
            # max_epochs and a model that is already trained would not
            # yield an independent run.
            logger = TensorBoardLogger(
                "tb_logs", name=f"resnet_train_{bucket}{suffix}"
            )
            model = LightningResnetReg(
                learning_rate=config.LEARNING_RATE,
                weights=config.WEIGHTS,
                num_classes=config.NUM_CLASSES,
            )
            trainer = pl.Trainer(
                logger=logger,
                accelerator=config.ACCELERATOR,
                devices=config.DEVICE,
                max_epochs=config.MAX_EPOCHS,
                default_root_dir=config.LOG_DIR,
            )

            train_loader, val_loader = loaders[bucket]
            print(f"Training {label}")
            trainer.fit(model, train_loader, val_loader)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()