MagicaNeko commited on
Commit
5d2f37a
·
verified ·
1 Parent(s): b4bbe05

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +198 -0
model.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from torch.utils.data import DataLoader, Dataset
5
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
6
+ import numpy as np
7
+ from PIL import Image
8
+ import random
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class CustomDataset(Dataset):
13
+
14
+ def __init__(self, red_dir, green_dir, blue_dir, nir_dir, mask_dir, pytorch=True):
15
+ super().__init__()
16
+ self.red_dir = red_dir
17
+ self.green_dir = green_dir
18
+ self.blue_dir = blue_dir
19
+ self.nir_dir = nir_dir
20
+ self.mask_dir = mask_dir
21
+
22
+ red_files = [f for f in self.red_dir.iterdir() if f.is_file()]
23
+ self.files = [self.combine_files(f) for f in red_files]
24
+ self.pytorch = pytorch
25
+
26
+
27
+ def combine_files(self, red_files: Path):
28
+ base_name = red_files.name
29
+
30
+ files = {
31
+ 'red': red_files,
32
+ 'green': self.green_dir / base_name.replace('red', 'green'),
33
+ 'blue': self.blue_dir / base_name.replace('red', 'blue'),
34
+ 'nir': self.nir_dir / base_name.replace('red', 'nir'),
35
+ 'mask': self.mask_dir / base_name.replace('red', 'gt'),
36
+ }
37
+
38
+ for key, path in files.items():
39
+ if not path.exists():
40
+ raise FileNotFoundError(f'Missing file: {path} for {red_files}')
41
+ return files
42
+
43
+
44
+ def __len__(self):
45
+ return len(self.files)
46
+
47
+
48
+ def open_as_array(self, idx, invert=False, nir_included=False):
49
+ rgb = np.stack([
50
+ np.array(Image.open(self.files[idx]['red'])),
51
+ np.array(Image.open(self.files[idx]['green'])),
52
+ np.array(Image.open(self.files[idx]['blue']))
53
+ ], axis=2)
54
+
55
+ if nir_included:
56
+ nir = np.array(Image.open(self.files[idx]['nir']))
57
+ nir = np.expand_dims(nir, 2)
58
+ rgb = np.concatenate([rgb, nir], axis=2)
59
+
60
+ if invert:
61
+ rgb = rgb.transpose((2, 0, 1))
62
+
63
+ raw_rgb = (rgb / np.iinfo(rgb.dtype).max)
64
+ return raw_rgb
65
+
66
+ def open_mask(self,idx, expand_dims=True):
67
+ raw_mask = np.array(Image.open(self.files[idx]['mask']))
68
+ raw_mask = np.where(raw_mask == 255, 1, 0) # Transform the mask into binary array where pixels with value 256(white) become 1(clouds), pixels with 0 or anything else becomes 0(not clouds)
69
+
70
+ return np.expand_dims(raw_mask, 0) if expand_dims else raw_mask
71
+
72
+ def __getitem__(self, idx):
73
+ X = torch.tensor(self.open_as_array(idx, invert=True, nir_included=True), dtype=torch.float32)
74
+ y = torch.tensor(self.open_mask(idx, expand_dims=True), dtype=torch.float32)
75
+
76
+ return X, y
77
+
78
+
79
+ class doubleConv(nn.Module):
80
+
81
+ def __init__(self, in_channels, out_channels):
82
+ super().__init__()
83
+
84
+ self.double_conv = nn.Sequential(
85
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
86
+ nn.BatchNorm2d(out_channels),
87
+ nn.ReLU(),
88
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
89
+ nn.BatchNorm2d(out_channels),
90
+ nn.ReLU()
91
+ )
92
+
93
+ def forward(self, x):
94
+ return self.double_conv(x)
95
+
96
+
97
+ class downSample(nn.Module):
98
+
99
+ def __init__(self, in_channels, out_channels):
100
+ super().__init__()
101
+
102
+ self.conv = doubleConv(in_channels, out_channels)
103
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
104
+
105
+ def forward(self, x):
106
+ down = self.conv(x)
107
+ p = self.pool(down)
108
+
109
+ return down, p
110
+
111
+
112
+ class upSample(nn.Module):
113
+ def __init__(self, in_channels, out_channels):
114
+ super().__init__()
115
+
116
+ self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
117
+ self.conv = doubleConv(out_channels * 2, out_channels)
118
+
119
+ def forward(self, x1, x2):
120
+ x1 = self.up(x1)
121
+ x = torch.cat([x1, x2], 1)
122
+ return self.conv(x)
123
+
124
+
125
+ class SpatialAttention(nn.Module):
126
+
127
+ def __init__(self):
128
+ super().__init__()
129
+ self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
130
+ self.sigmoid = nn.Sigmoid()
131
+
132
+ def forward(self, x):
133
+ avg_pooling = torch.mean(x, dim=1, keepdim=True)
134
+ max_pooling = torch.max(x, dim=1, keepdim=True)[0] # return on max values and not their indices
135
+ concat = torch.cat([avg_pooling, max_pooling], dim=1)
136
+ attention = self.conv(concat)
137
+ attention = self.sigmoid(attention)
138
+ output = x * attention
139
+ return output
140
+
141
+
142
+ class UNet(nn.Module):
143
+ def __init__(self, in_channels, num_classes):
144
+ super().__init__()
145
+
146
+ self.down_conv1 = downSample(in_channels, 32)
147
+ self.down_conv2 = downSample(32, 64)
148
+ self.down_conv3 = downSample(64, 128)
149
+
150
+ self.bottleneck = doubleConv(128, 256)
151
+ self.spatial_attention = SpatialAttention()
152
+
153
+ self.up_conv1 = upSample(256, 128)
154
+ self.up_conv2 = upSample(128, 64)
155
+ self.up_conv3 = upSample(64, 32)
156
+
157
+ self.out = nn.Conv2d(in_channels=32 , out_channels=num_classes, kernel_size=1)
158
+
159
+ def forward(self, x):
160
+
161
+ down1, p1 = self.down_conv1(x)
162
+ down2, p2 = self.down_conv2(p1)
163
+ down3, p3 = self.down_conv3(p2)
164
+
165
+ b = self.bottleneck(p3)
166
+ b = self.spatial_attention(b)
167
+
168
+ up1 = self.up_conv1(b, down3)
169
+ up2 = self.up_conv2(up1, down2)
170
+ up3 = self.up_conv3(up2, down1)
171
+
172
+ output = self.out(up3)
173
+ return output
174
+
175
+ def acc_fn(predb, yb):
176
+ preds = torch.sigmoid(predb) # Convert logits to probabilities
177
+ preds = (preds > 0.5).float() # Threshold at 0.5
178
+ return (preds == yb).float().mean() # Compare with ground truth
179
+
180
+ def calculate_metrics(y_true, y_pred):
181
+ TP = torch.sum((y_true == 1) & (y_pred == 1)).float()
182
+ TN = torch.sum((y_true == 0) & (y_pred == 0)).float()
183
+ FP = torch.sum((y_true == 0) & (y_pred == 1)).float()
184
+ FN = torch.sum((y_true == 1) & (y_pred == 0)).float()
185
+
186
+ jaccard = TP / (TP + FN + FP + 1e-10)
187
+ precision = TP / (TP + FP + 1e-10)
188
+ recall = TP / (TP + FN + 1e-10)
189
+ specificity = TN / (TN + FP + 1e-10)
190
+ overall_acc = (TP + TN) / (TP + TN + FP + FN + 1e-10)
191
+
192
+ return {
193
+ "Jaccard index": jaccard.item(),
194
+ "Precision": precision.item(),
195
+ "Recall": recall.item(),
196
+ "Specificity": specificity.item(),
197
+ "Overall Accuracy": overall_acc.item()
198
+ }