Anirudh Bhalekar committed
Commit a3f0d6c · 1 Parent(s): 6e87dc7

added models and util folder

Files changed (50)
  1. models_Facies.py +397 -0
  2. models_Fault.py +327 -0
  3. util/__pycache__/datasets.cpython-311.pyc +0 -0
  4. util/__pycache__/datasets.cpython-312.pyc +0 -0
  5. util/__pycache__/datasets.cpython-36.pyc +0 -0
  6. util/__pycache__/datasets.cpython-37.pyc +0 -0
  7. util/__pycache__/lars.cpython-36.pyc +0 -0
  8. util/__pycache__/lr_decay.cpython-311.pyc +0 -0
  9. util/__pycache__/lr_decay.cpython-312.pyc +0 -0
  10. util/__pycache__/lr_decay.cpython-36.pyc +0 -0
  11. util/__pycache__/lr_decay.cpython-37.pyc +0 -0
  12. util/__pycache__/lr_sched.cpython-311.pyc +0 -0
  13. util/__pycache__/lr_sched.cpython-312.pyc +0 -0
  14. util/__pycache__/lr_sched.cpython-36.pyc +0 -0
  15. util/__pycache__/lr_sched.cpython-37.pyc +0 -0
  16. util/__pycache__/metrics.cpython-36.pyc +0 -0
  17. util/__pycache__/misc.cpython-311.pyc +0 -0
  18. util/__pycache__/misc.cpython-312.pyc +0 -0
  19. util/__pycache__/misc.cpython-36.pyc +0 -0
  20. util/__pycache__/misc.cpython-37.pyc +0 -0
  21. util/__pycache__/msssim.cpython-311.pyc +0 -0
  22. util/__pycache__/msssim.cpython-312.pyc +0 -0
  23. util/__pycache__/msssim.cpython-36.pyc +0 -0
  24. util/__pycache__/msssim.cpython-37.pyc +0 -0
  25. util/__pycache__/pos_embed.cpython-311.pyc +0 -0
  26. util/__pycache__/pos_embed.cpython-312.pyc +0 -0
  27. util/__pycache__/pos_embed.cpython-36.pyc +0 -0
  28. util/__pycache__/pos_embed.cpython-37.pyc +0 -0
  29. util/__pycache__/size_aware_batching.cpython-312.pyc +0 -0
  30. util/__pycache__/skeletonize.cpython-312.pyc +0 -0
  31. util/__pycache__/tools.cpython-311.pyc +0 -0
  32. util/__pycache__/tools.cpython-312.pyc +0 -0
  33. util/__pycache__/tools.cpython-36.pyc +0 -0
  34. util/__pycache__/tools.cpython-37.pyc +0 -0
  35. util/__pycache__/variable_pos_embed.cpython-312.pyc +0 -0
  36. util/crop.py +42 -0
  37. util/datasets.py +599 -0
  38. util/lars.py +47 -0
  39. util/lr_decay.py +76 -0
  40. util/lr_sched.py +21 -0
  41. util/metrics.py +90 -0
  42. util/misc.py +340 -0
  43. util/msssim.py +146 -0
  44. util/pos_embed.py +104 -0
  45. util/pos_embedtest.py +127 -0
  46. util/post_processing.py +305 -0
  47. util/size_aware_batching.py +251 -0
  48. util/skeletonize.py +486 -0
  49. util/tools.py +143 -0
  50. util/variable_pos_embed.py +143 -0
models_Facies.py ADDED
@@ -0,0 +1,397 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer
import numpy as np
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable


class FlexiblePatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding that handles variable input sizes """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_patches = (img_size // patch_size) ** 2  # default number of patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        # Calculate number of patches dynamically
        self.num_patches = (H // self.patch_size) * (W // self.patch_size)
        x = self.proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for variable image sizes and adaptive positional embeddings
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        self.decoder = VIT_MLAHead(mla_channels=self.embed_dim, num_classes=self.num_classes)

        self.segmentation_head = SegmentationHead(
            in_channels=16,
            out_channels=self.num_classes,
            kernel_size=3,
        )
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def interpolate_pos_encoding(self, x, h, w):
        """
        Interpolate positional embeddings for arbitrary input sizes
        """
        npatch = x.shape[1] - 1  # subtract 1 for cls token
        N = self.pos_embed.shape[1] - 1  # original number of patches

        if npatch == N and h == w:
            return self.pos_embed

        # Use the new variable position embedding utility
        return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)

    def forward_features(self, x):
        B, C, H, W = x.shape

        # Handle padding for non-16-divisible images
        patch_size = self.patch_embed.patch_size
        pad_h = (patch_size - H % patch_size) % patch_size
        pad_w = (patch_size - W % patch_size) % patch_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
            H_padded, W_padded = H + pad_h, W + pad_w
        else:
            H_padded, W_padded = H, W

        # Extract patches
        x = self.patch_embed(x)
        _H, _W = H_padded // patch_size, W_padded // patch_size

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add interpolated positional embeddings
        pos_embed = self.interpolate_pos_encoding(x, _H, _W)
        x = x + pos_embed
        x = self.pos_drop(x)

        featureskip = []
        featureskipnum = 1
        for blk in self.blocks:
            x = blk(x)
            if featureskipnum % (len(self.blocks) // 4) == 0:
                featureskip.append(x[:, 1:, :])  # exclude cls token
            featureskipnum += 1

        # Pass original dimensions for proper reconstruction
        x = self.decoder(featureskip[0], featureskip[1], featureskip[2], featureskip[3],
                         h=_H, w=_W, target_h=H, target_w=W)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        # print(x.shape,skip.shape)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.up(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self):
        super().__init__()
        # self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            1024,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )

        decoder_channels = (256, 128, 64, 16)

        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        # if self.config.n_skip != 0:
        #     skip_channels = self.config.skip_channels
        #     for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
        #         skip_channels[3 - i] = 0
        # else:
        #     skip_channels = [0, 0, 0, 0]
        skip_channels = [512, 256, 128, 64]
        self.conv_feature1 = Conv2dReLU(1024, skip_channels[0], kernel_size=3, padding=1, use_batchnorm=True)
        self.conv_feature2 = Conv2dReLU(1024, skip_channels[1], kernel_size=3, padding=1, use_batchnorm=True)
        self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv_feature3 = Conv2dReLU(1024, skip_channels[2], kernel_size=3, padding=1, use_batchnorm=True)
        self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
        self.conv_feature4 = Conv2dReLU(1024, skip_channels[3], kernel_size=3, padding=1, use_batchnorm=True)
        self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)

        # skip_channels=[128,64,32,8]
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def TransShape(self, x, head_channels=512, up=0):
        B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)

        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        if up == 0:
            x = self.conv_feature1(x)
        elif up == 1:
            x = self.conv_feature2(x)
            x = self.up2(x)
        elif up == 2:
            x = self.conv_feature3(x)
            x = self.up3(x)
        elif up == 3:
            x = self.conv_feature4(x)
            x = self.up4(x)
        return x

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        skip_channels = [512, 256, 128, 64]
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = self.TransShape(features[i], head_channels=skip_channels[i], up=i)
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class MLAHead(nn.Module):
    def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None):
        super(MLAHead, self).__init__()
        self.head2 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head3 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head4 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head5 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())

    def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
        head2 = F.interpolate(self.head2(mla_p2), (4 * mla_p2.shape[-2], 4 * mla_p2.shape[-1]), mode='bilinear', align_corners=True)
        head3 = F.interpolate(self.head3(mla_p3), (4 * mla_p3.shape[-2], 4 * mla_p3.shape[-1]), mode='bilinear', align_corners=True)
        head4 = F.interpolate(self.head4(mla_p4), (4 * mla_p4.shape[-2], 4 * mla_p4.shape[-1]), mode='bilinear', align_corners=True)
        head5 = F.interpolate(self.head5(mla_p5), (4 * mla_p5.shape[-2], 4 * mla_p5.shape[-1]), mode='bilinear', align_corners=True)
        return torch.cat([head2, head3, head4, head5], dim=1)


class VIT_MLAHead(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self, img_size=768, mla_channels=256, mlahead_channels=128, num_classes=6,
                 norm_layer=nn.BatchNorm2d, norm_cfg=None, **kwargs):
        super(VIT_MLAHead, self).__init__(**kwargs)
        self.img_size = img_size
        self.norm_cfg = norm_cfg
        self.mla_channels = mla_channels
        self.BatchNorm = norm_layer
        self.mlahead_channels = mlahead_channels
        self.num_classes = num_classes
        self.mlahead = MLAHead(mla_channels=self.mla_channels,
                               mlahead_channels=self.mlahead_channels, norm_cfg=self.norm_cfg)
        self.cls = nn.Conv2d(4 * self.mlahead_channels,
                             self.num_classes, 3, padding=1)

    def forward(self, x1, x2, x3, x4, h=14, w=14, target_h=None, target_w=None):
        B, n_patch, hidden = x1.size()
        if h == w:
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))

        # Reshape all feature maps
        x1 = x1.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x2 = x2.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x3 = x3.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x4 = x4.permute(0, 2, 1).contiguous().view(B, hidden, h, w)

        # Apply MLA head
        x = self.mlahead(x1, x2, x3, x4)
        x = self.cls(x)

        # Calculate target size - if original image wasn't patch-size divisible
        patch_size = 16  # assuming patch size of 16
        if target_h is not None and target_w is not None:
            target_size = (target_h, target_w)
        else:
            target_size = (h * patch_size, w * patch_size)

        # Interpolate to target size
        x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        return x


def mae_vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1024
    )
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=14,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1280
    )
    return model
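A minimal usage sketch for the facies model above, assuming the `util` package is importable, a timm version whose `VisionTransformer` accepts the `img_size`/`in_chans`/`num_classes` keywords passed through `**kwargs`, and that `util.variable_pos_embed.interpolate_pos_embed_variable` resizes the position embedding as its name suggests:

    import torch
    import models_Facies

    # 1-channel seismic input, 6 facies classes; weights would normally come from a checkpoint.
    model = models_Facies.vit_base_patch16(img_size=224, in_chans=1, num_classes=6)
    model.eval()

    # forward_features pads inputs that are not divisible by the 16-pixel patch size,
    # so an odd-sized section can be fed directly.
    x = torch.randn(1, 1, 401, 255)
    with torch.no_grad():
        logits = model(x)   # expected shape: [1, 6, 401, 255]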
models_Fault.py ADDED
@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------


from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer
import numpy as np
from util.msssim import MSSSIM
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable


class FlexiblePatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding that handles variable input sizes """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_patches = (img_size // patch_size) ** 2  # default number of patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        # Calculate number of patches dynamically
        self.num_patches = (H // self.patch_size) * (W // self.patch_size)
        x = self.proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        self.decoder = DecoderCup(in_channels=[self.embed_dim, 256, 128, 64])

        self.segmentation_head = SegmentationHead(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=1
        )
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def interpolate_pos_encoding(self, x, h, w):
        """
        Interpolate positional embeddings for arbitrary input sizes
        """
        npatch = x.shape[1] - 1  # subtract 1 for cls token
        N = self.pos_embed.shape[1] - 1  # original number of patches

        if npatch == N and h == w:
            return self.pos_embed

        # Use the new variable position embedding utility
        return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)

    def generate_mask(self, input_tensor, ratio):
        mask = torch.zeros_like(input_tensor)
        indices = torch.randperm(mask.size(3) // 16)[:int(mask.size(3) // 16 * ratio)]
        sorted_indices = torch.sort(indices)[0]
        for i in range(0, len(sorted_indices)):
            mask[:, :, :, sorted_indices[i] * 16:(sorted_indices[i] + 1) * 16] = 1
        return mask

    def forward_features(self, x):
        B, C, H, W = x.shape

        # Handle padding for non-16-divisible images
        patch_size = self.patch_embed.patch_size
        pad_h = (patch_size - H % patch_size) % patch_size
        pad_w = (patch_size - W % patch_size) % patch_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
            H_padded, W_padded = H + pad_h, W + pad_w
        else:
            H_padded, W_padded = H, W

        img = x
        x = self.patch_embed(x)

        _H, _W = H_padded // patch_size, W_padded // patch_size

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add interpolated positional embeddings
        pos_embed = self.interpolate_pos_encoding(x, _H, _W)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        x = self.decoder(x[:, 1:, :], img)
        x = self.segmentation_head(x)
        return x

    def forward(self, x):

        x = self.forward_features(x)

        return x

    def inference(self, x):
        x = self.forward_features(x)
        x = F.softmax(x, dim=1)

        return x


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=1, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, in_channels=[1024, 256, 128, 64]):
        super().__init__()
        head_channels = 512
        self.conv_more = Conv2dReLU(
            1,
            32,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        skip_channels = [0, 0, 0, 32]
        out_channels = [256, 128, 64, 64]
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, img, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        skip_channels = [None, None, None, self.conv_more(img)]
        for i, decoder_block in enumerate(self.blocks):
            x = decoder_block(x, skip=skip_channels[i])
        return x


def forward_loss(imgs, pred):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0 is keep, 1 is remove,
    """
    loss1f = torch.nn.MSELoss()
    loss1 = loss1f(imgs, pred)
    loss2f = MSSSIM()
    loss2 = loss2f(imgs, pred)
    a = 0.5
    loss = (1 - a) * loss1 + a * loss2
    return loss


def weighted_cross_entropy(pred, target):
    """
    Compute the weighted cross entropy loss.
    NEED VERIFICATION
    """

    # Function to compute weighted cross entropy loss
    # target: [batch, channel, s, s]
    # pred: [batch, channel, s, s]

    # print('pred shape ', pred.shape)
    # print('target shape ', target.shape)
    # print('--------------')
    # print('sums of pred', torch.sum(pred))
    # print('sums of target', torch.sum(target))
    # beta is the fraction of non-fault pixels in the target (i.e the zeroes in the target)
    beta = torch.mean(target)  # fraction of fault pixels
    beta = 1 - beta  # fraction of non-fault pixels
    beta = torch.clamp(beta, min=0.01, max=0.99)  # avoid division by zero

    # print('beta', beta)

    # Compute the weighted cross entropy loss
    loss = -(beta * target * torch.log(pred + 1e-8) + (1 - beta) * (1 - target) * torch.log(1 - pred + 1e-8))
    return torch.mean(loss)


def mae_vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1024
    )
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=14,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1280
    )
    return model
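An illustrative check of the class-balanced loss defined above (a sketch only; `pred` is assumed to already hold per-pixel fault probabilities in (0, 1), e.g. the output of `inference()`):

    import torch
    from models_Fault import weighted_cross_entropy

    # Hypothetical batch: probability maps and sparse binary fault masks.
    pred = torch.rand(2, 1, 64, 64)                     # per-pixel fault probabilities
    target = (torch.rand(2, 1, 64, 64) > 0.9).float()   # roughly 10% fault pixels

    loss = weighted_cross_entropy(pred, target)
    print(loss.item())  # scalar; the rare fault pixels receive the larger weight beta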
util/__pycache__/datasets.cpython-311.pyc ADDED
Binary file (23.6 kB).
util/__pycache__/datasets.cpython-312.pyc ADDED
Binary file (32.1 kB).
util/__pycache__/datasets.cpython-36.pyc ADDED
Binary file (19.2 kB).
util/__pycache__/datasets.cpython-37.pyc ADDED
Binary file (19.6 kB).
util/__pycache__/lars.cpython-36.pyc ADDED
Binary file (1.34 kB).
util/__pycache__/lr_decay.cpython-311.pyc ADDED
Binary file (2.66 kB).
util/__pycache__/lr_decay.cpython-312.pyc ADDED
Binary file (2.39 kB).
util/__pycache__/lr_decay.cpython-36.pyc ADDED
Binary file (1.6 kB).
util/__pycache__/lr_decay.cpython-37.pyc ADDED
Binary file (1.59 kB).
util/__pycache__/lr_sched.cpython-311.pyc ADDED
Binary file (1.02 kB).
util/__pycache__/lr_sched.cpython-312.pyc ADDED
Binary file (1.12 kB).
util/__pycache__/lr_sched.cpython-36.pyc ADDED
Binary file (595 Bytes).
util/__pycache__/lr_sched.cpython-37.pyc ADDED
Binary file (599 Bytes).
util/__pycache__/metrics.cpython-36.pyc ADDED
Binary file (3.83 kB).
util/__pycache__/misc.cpython-311.pyc ADDED
Binary file (21.4 kB).
util/__pycache__/misc.cpython-312.pyc ADDED
Binary file (19.4 kB).
util/__pycache__/misc.cpython-36.pyc ADDED
Binary file (10.8 kB).
util/__pycache__/misc.cpython-37.pyc ADDED
Binary file (10.8 kB).
util/__pycache__/msssim.cpython-311.pyc ADDED
Binary file (8.9 kB).
util/__pycache__/msssim.cpython-312.pyc ADDED
Binary file (7.84 kB).
util/__pycache__/msssim.cpython-36.pyc ADDED
Binary file (4.51 kB).
util/__pycache__/msssim.cpython-37.pyc ADDED
Binary file (4.49 kB).
util/__pycache__/pos_embed.cpython-311.pyc ADDED
Binary file (4.35 kB).
util/__pycache__/pos_embed.cpython-312.pyc ADDED
Binary file (4.14 kB).
util/__pycache__/pos_embed.cpython-36.pyc ADDED
Binary file (2.43 kB).
util/__pycache__/pos_embed.cpython-37.pyc ADDED
Binary file (2.42 kB).
util/__pycache__/size_aware_batching.cpython-312.pyc ADDED
Binary file (10.6 kB).
util/__pycache__/skeletonize.cpython-312.pyc ADDED
Binary file (35.5 kB).
util/__pycache__/tools.cpython-311.pyc ADDED
Binary file (7.76 kB).
util/__pycache__/tools.cpython-312.pyc ADDED
Binary file (6.81 kB).
util/__pycache__/tools.cpython-36.pyc ADDED
Binary file (4.25 kB).
util/__pycache__/tools.cpython-37.pyc ADDED
Binary file (4.26 kB).
util/__pycache__/variable_pos_embed.cpython-312.pyc ADDED
Binary file (5.42 kB).
util/crop.py ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
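A sketch of how this loop-free crop could slot into a torchvision pipeline (parameter values are illustrative):

    from torchvision import transforms
    from util.crop import RandomResizedCrop

    train_transform = transforms.Compose([
        RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3),  # 3 = bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])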
util/datasets.py ADDED
@@ -0,0 +1,599 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

import os, random, glob
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from os.path import isfile, join
import segyio
from itertools import permutations

random.seed(42)

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


## pretrain
class SeismicSet(data.Dataset):

    def __init__(self, path, input_size) -> None:
        super().__init__()
        # self.file_list = os.listdir(path)
        # self.file_list = [os.path.join(path, f) for f in self.file_list]
        self.get_file_list(path)
        self.input_size = input_size
        print(len(self.file_list))

    def __len__(self) -> int:
        return len(self.file_list)
        # return 100000

    def __getitem__(self, index):
        d = np.fromfile(self.file_list[index], dtype=np.float32)
        d = d.reshape(1, self.input_size, self.input_size)
        d = (d - d.mean()) / (d.std() + 1e-6)

        # return to_transforms(d, self.input_size)
        return d, torch.tensor([1])

    def get_file_list(self, path):
        dirs = [os.path.join(path, f) for f in os.listdir(path)]
        self.file_list = dirs

        # for ds in dirs:
        #     if os.path.isdir(ds):
        #         self.file_list += [os.path.join(ds, f) for f in os.listdir(ds)]

        return random.shuffle(self.file_list)


def to_transforms(d, input_size):
    t = transforms.Compose([
        transforms.RandomResizedCrop(input_size,
                                     scale=(0.2, 1.0),
                                     interpolation=3),  # 3 is bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    return t(d)


### finetune
class FacesSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[768, 768],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape

        # self.data_list = sorted(glob.glob(folder + 'seismic/*.dat'))
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(117)]

        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list[:100]
        elif not is_train:
            self.data_list = self.data_list[100:]
        self.label_list = [
            f.replace('/seismic/', '/label/') for f in self.data_list
        ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape) - 1
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class SaltSet(data.Dataset):

    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(4000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list[:3500]
        elif not is_train:
            self.data_list = self.data_list[3500:]
        self.label_list = [
            f.replace('/seismic/', '/label/') for f in self.data_list
        ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class InterpolationSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + str(f) + '.dat' for f in range(6000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
        elif not is_train:
            self.data_list = [folder + 'U' + str(f) + '.dat' for f in range(2000, 4000)]
        self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(d)

    def __len__(self):
        return len(self.data_list)
        # return 10000


class DenoiseSet(data.Dataset):
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
            self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]
        elif not is_train:
            self.data_list = [folder + 'field/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        # d = (d - d.mean())/d.std()
        l = np.fromfile(self.label_list[index], np.float32)
        l = l.reshape([1] + self.shape)
        # l = (l - d.mean())/l.std()
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ReflectSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2200)]

        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
            self.label_list = [
                f.replace('/seismic/', '/label/') for f in self.data_list
            ]
        elif not is_train:
            self.data_list = [folder + 'SEAMseismic/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = [
                f.replace('/SEAMseismic/', '/SEAMreflect/') for f in self.data_list
            ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d - d.mean()
        d = d / (d.std() + 1e-6)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32)
        l = l - l.mean()
        l = l / (l.std() + 1e-6)
        l = l.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ThebeSet(data.Dataset):
    def __init__(self, folder, shape=[224, 224], mode='train') -> None:
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.num_files = len(os.listdir(join(folder, 'fault')))
        self.shape = shape
        self.fault_list = [folder + '/fault/{i}.npy'.format(i=i) for i in range(1, self.num_files + 1)]
        self.seis_list = [folder + '/seis/{i}.npy'.format(i=i) for i in range(1, self.num_files + 1)]

        self.train_size = int(0.75 * self.num_files)
        self.val_size = int(0.15 * self.num_files)
        self.test_size = self.num_files - self.train_size - self.val_size

        self.train_index = self.train_size
        self.val_index = self.train_index + self.val_size

        if mode == 'train':
            self.fault_list = self.fault_list[:self.train_index]
            self.seis_list = self.seis_list[:self.train_index]
        elif mode == 'val':
            self.fault_list = self.fault_list[self.train_index:self.val_index]
            self.seis_list = self.seis_list[self.train_index:self.val_index]
        elif mode == 'test':
            self.fault_list = self.fault_list[self.val_index:]
            self.seis_list = self.seis_list[self.val_index:]
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        return len(self.fault_list)

    def retrieve_patch(self, fault, seis):
        # image will (probably) be of size [3174, 1537]
        # return a patch of size [224, 224]

        patch_height = self.shape[0]
        patch_width = self.shape[1]

        h, w = fault.shape
        if h < patch_height or w < patch_width:
            raise ValueError(f"Image dimensions must be at least {patch_height}x{patch_width}.")

        top = random.randint(0, h - patch_height)
        left = random.randint(0, w - patch_width)

        return fault[top:top + patch_height, left:left + patch_width], seis[top:top + patch_height, left:left + patch_width]

    def random_transform(self, fault, seis):
        # Apply the same random transformations to the fault and seismic data
        # Mirror the patch horizontally
        if random.random() > 0.5:
            fault = np.fliplr(fault)
            seis = np.fliplr(seis)

        # Mirror the patch vertically
        if random.random() > 0.5:
            fault = np.flipud(fault)
            seis = np.flipud(seis)

        return fault, seis

    def __getitem__(self, index):
        # need to see if we do normalization here (i.e. what data pre-treatment we do)
        fault = np.load(self.fault_list[index])
        seis = np.load(self.seis_list[index])

        fault, seis = self.retrieve_patch(fault, seis)
        fault, seis = self.random_transform(fault, seis)

        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        fault = torch.tensor(fault.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)

        return seis, fault


class FSegSet(data.Dataset):
    def __init__(self, folder, shape=[128, 128], mode='train') -> None:
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.shape = shape
        self.mode = mode

        if mode == 'train':
            self.fault_path = join(self.folder, 'train/fault')
            self.seis_path = join(self.folder, 'train/seis')
        elif mode == 'val':
            self.fault_path = join(self.folder, 'val/fault')
            self.seis_path = join(self.folder, 'val/seis')
        else:
            raise ValueError("Mode must be 'train' or 'val'.")

        self.fault_list = [join(self.fault_path, f) for f in os.listdir(self.fault_path) if f.endswith('.npy')]
        self.seis_list = [join(self.seis_path, f) for f in os.listdir(self.seis_path) if f.endswith('.npy')]

    def __len__(self):
        return len(self.fault_list)

    def __getitem__(self, index):

        fault_img, seis_img = np.load(self.fault_list[index]), np.load(self.seis_list[index])
        # These will be 128x128

        seis_img = (seis_img - seis_img.mean()) / (seis_img.std() + 1e-6)

        fault = torch.tensor(fault_img.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis_img.copy(), dtype=torch.float32).unsqueeze(0)

        return seis, fault


class F3DFaciesSet(data.Dataset):
    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.seises = np.load(join(folder, "{}/seismic.npy".format(mode)))
        self.labels = np.load(join(folder, "{}/labels.npy".format(mode)))
        self.image_shape = shape

        if mode == 'train':
            self.size_categories = [
                (401, 701),
                (701, 255),
                (401, 255)
            ]
        elif mode == 'val':
            self.size_categories = [
                (601, 200),
                (200, 255),
                (601, 255)
            ]
        elif mode == 'test':
            self.size_categories = [
                (701, 255),
                (200, 701),
                (200, 255)
            ]
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        # We will take cross sections along each dimension, so the length is the sum of all dimensions
        return sum(self.seises.shape)

    def random_transform(self, label, seis):
        # Apply the same random transformations to the label and seismic data
        # Mirror the patch horizontally
        if random.random() > 0.5:
            label = np.fliplr(label)
            seis = np.fliplr(seis)

        # Mirror the patch vertically
        if random.random() > 0.5:
            label = np.flipud(label)
            seis = np.flipud(seis)

        return label, seis

    def __getitem__(self, index):

        m1, m2, m3 = self.seises.shape

        if index < m1:
            seis, label = self.seises[index, :, :], self.labels[index, :, :]
        elif index < m1 + m2:
            seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
        elif index < m1 + m2 + m3:
            seis, label = self.seises[:, :, index - m1 - m2], self.labels[:, :, index - m1 - m2]
        else:
            raise IndexError("Index out of bounds")

        seis, label = self.random_transform(seis, label)
        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        seis, label = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0), torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)

        # label is now shape [1, H, W]
        # we want shape [6, H, W] with each slice being a binary mask depending on the int value of label
        label = label.squeeze(0)
        label = (label == torch.arange(6, device=label.device).view(6, 1, 1)).float()

        return seis, label


class P3DFaciesSet(data.Dataset):
    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.random_resize = random_resize

        # Validation set will be validation set from F3DSet
        if mode == 'val': mode = 'train'  # TEMPORARY SINCE P3D does not have labelled val set

        self.mode = mode
        self.image_shape = shape

        self.s_path = join(folder, "{}/seismic.segy".format(mode))
        self.l_path = join(folder, "{}/labels.segy".format(mode))

        if mode != 'val':
            with segyio.open(self.s_path, ignore_geometry=True) as seis_file:
                self.seises = seis_file.trace.raw[:]

            if self.mode in ['val', 'train']:
                with segyio.open(self.l_path, ignore_geometry=True) as label_file:
                    self.labels = label_file.trace.raw[:]
            else:
                # Since the test files are unlabeled
                self.labels = np.zeros_like(self.seises)
        else:
            f3d_file_path = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\F3D_facies_DATASET"
            self.seises = np.load(join(f3d_file_path, "val/seismic.npy"))
            self.labels = np.load(join(f3d_file_path, "val/labels.npy"))

        if mode == 'train':
            m1, m2, m3 = 590, 782, 1006
        elif mode == 'val':
            m1, m2, m3 = 601, 200, 255
        elif mode == 'test_1':
            m1, m2, m3 = 841, 334, 1006
        elif mode == 'test_2':
            m1, m2, m3 = 251, 782, 1006
        else:
            raise ValueError("Mode must be 'train', 'test_2', 'val', or 'test_1'.")

        self.size_categories = list(permutations([m1, m2, m3], 2))

        self.seises = self.seises.reshape(m1, m2, m3)
        self.labels = self.labels.reshape(m1, m2, m3)

    def __len__(self):
        # We will take cross sections along the first 2 dimensions ONLY
        return self.seises.shape[0] + self.seises.shape[1]

    def _random_transform(self, label, seis):
        # Apply the same random transformations to the label and seismic data
        # Mirror the patch horizontally
        if random.random() > 0.5:
            label = np.fliplr(label)
            seis = np.fliplr(seis)

        # Mirror the patch vertically
        if random.random() > 0.5:
            label = np.flipud(label)
            seis = np.flipud(seis)

        # random rotation to 2D image label, seis
        # r_int = random.randint(0, 3)
        # label = np.rot90(label, r_int)
        # seis = np.rot90(seis, r_int)

        return label, seis

    def _random_resize(self, label, seis, min_size=(256, 256)):
        # Randomly resize the label and seismic data
        r_height = random.randint(min_size[0], seis.shape[0])
        r_width = random.randint(min_size[1], seis.shape[1])

        r_pos_x = random.randint(0, seis.shape[0] - r_height)
        r_pos_y = random.randint(0, seis.shape[1] - r_width)

        label = label[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]
        seis = seis[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]

        return label, seis

    def __getitem__(self, index):

        m1, m2, m3 = self.seises.shape

        if index < m1:
            seis, label = self.seises[index, :, :], self.labels[index, :, :]
        elif index < m1 + m2:
            seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
        elif index < m1 + m2 + m3:
            seis, label = self.seises[:, :, index - m1 - m2], self.labels[:, :, index - m1 - m2]
        else:
            raise IndexError("Index out of bounds")

        seis, label = self._random_transform(seis, label)
        if self.random_resize: seis, label = self._random_resize(seis, label)

        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        seis, label = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0), torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)

        # label is now shape [1, H, W]
        # we want shape [6, H, W] with each slice being a binary mask depending on the int value of label
        label = label.squeeze(0)
        label = (label == torch.arange(1, 7, device=label.device).view(6, 1, 1)).float()

        return seis, label
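A minimal loader sketch for one of the datasets above (the path and the <root>/{train,val}/{seis,fault}/*.npy layout are assumptions, and the 128x128 patch size follows the comment in FSegSet.__getitem__):

    import torch
    from util.datasets import FSegSet

    train_set = FSegSet("/path/to/fseg_dataset", shape=[128, 128], mode="train")
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)

    seis, fault = next(iter(train_loader))  # expected: [16, 1, 128, 128] each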
util/lars.py ADDED
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# LARS optimizer, implementation from MoCo v3:
# https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------

import torch


class LARS(torch.optim.Optimizer):
    """
    LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
    """
    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                if p.ndim > 1:  # if not normalization gamma/beta or bias
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
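A small usage sketch of the optimizer (hyperparameter values are illustrative):

    import torch
    import torch.nn as nn
    from util.lars import LARS

    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-6, momentum=0.9)

    x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()        # >1D parameters get the trust-ratio scaling; biases/norm params do not
    optimizer.zero_grad()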
util/lr_decay.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import json
13
+
14
+
15
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16
+ """
17
+ Parameter groups for layer-wise lr decay
18
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19
+ """
20
+ param_group_names = {}
21
+ param_groups = {}
22
+
23
+ num_layers = len(model.blocks) + 1
24
+
25
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26
+
27
+ for n, p in model.named_parameters():
28
+ if not p.requires_grad:
29
+ continue
30
+
31
+ # no decay: all 1D parameters and model specific ones
32
+ if p.ndim == 1 or n in no_weight_decay_list:
33
+ g_decay = "no_decay"
34
+ this_decay = 0.
35
+ else:
36
+ g_decay = "decay"
37
+ this_decay = weight_decay
38
+
39
+ layer_id = get_layer_id_for_vit(n, num_layers)
40
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
41
+
42
+ if group_name not in param_group_names:
43
+ this_scale = layer_scales[layer_id]
44
+
45
+ param_group_names[group_name] = {
46
+ "lr_scale": this_scale,
47
+ "weight_decay": this_decay,
48
+ "params": [],
49
+ }
50
+ param_groups[group_name] = {
51
+ "lr_scale": this_scale,
52
+ "weight_decay": this_decay,
53
+ "params": [],
54
+ }
55
+
56
+ param_group_names[group_name]["params"].append(n)
57
+ param_groups[group_name]["params"].append(p)
58
+
59
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60
+
61
+ return list(param_groups.values())
62
+
63
+
64
+ def get_layer_id_for_vit(name, num_layers):
65
+ """
66
+ Assign a parameter with its layer id
67
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68
+ """
69
+ if name in ['cls_token', 'pos_embed']:
70
+ return 0
71
+ elif name.startswith('patch_embed'):
72
+ return 0
73
+ elif name.startswith('blocks'):
74
+ return int(name.split('.')[1]) + 1
75
+ else:
76
+ return num_layers
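A usage sketch under the assumption of a ViT-style model that exposes a .blocks list; the TinyViT stand-in below is purely illustrative. The returned groups keep an lr_scale key that adjust_learning_rate in util/lr_sched.py multiplies into the learning rate.

import torch
import torch.nn as nn
from util.lr_decay import param_groups_lrd

class TinyViT(nn.Module):
    # stand-in exposing the attributes param_groups_lrd expects
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Linear(16, 32)
        self.blocks = nn.ModuleList([nn.Linear(32, 32) for _ in range(4)])
        self.head = nn.Linear(32, 6)

model = TinyViT()
param_groups = param_groups_lrd(model, weight_decay=0.05,
                                no_weight_decay_list=['pos_embed', 'cls_token'],
                                layer_decay=0.75)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
print([g['lr_scale'] for g in param_groups])   # larger scale for layers closer to the head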
util/lr_sched.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ def adjust_learning_rate(optimizer, epoch, args):
10
+ """Decay the learning rate with half-cycle cosine after warmup"""
11
+ if epoch < args.warmup_epochs:
12
+ lr = args.lr * epoch / args.warmup_epochs
13
+ else:
14
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16
+ for param_group in optimizer.param_groups:
17
+ if "lr_scale" in param_group:
18
+ param_group["lr"] = lr * param_group["lr_scale"]
19
+ else:
20
+ param_group["lr"] = lr
21
+ return lr
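A quick sketch of the resulting schedule, assuming the usual argparse fields lr, min_lr, warmup_epochs and epochs (the values below are illustrative): the rate ramps linearly for warmup_epochs, then follows a half-cycle cosine down to min_lr.

import torch
from types import SimpleNamespace
from util.lr_sched import adjust_learning_rate

args = SimpleNamespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=50)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in [0, 2, 5, 25, 49]:
    print(epoch, adjust_learning_rate(optimizer, epoch, args))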
util/metrics.py ADDED
@@ -0,0 +1,90 @@
1
+
2
+ """
3
+ refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
4
+ """
5
+ import numpy as np
6
+
7
+ __all__ = ['SegmentationMetric']
8
+
9
+ """
10
+ confusionMetric  # note: here the horizontal direction is the prediction and the vertical direction is the ground truth, the opposite of the usual convention
11
+ P\L P N
12
+ P TP FP
13
+ N FN TN
14
+ """
15
+
16
+
17
+ class SegmentationMetric(object):
18
+ def __init__(self, numClass):
19
+ self.numClass = numClass
20
+ self.confusionMatrix = np.zeros((self.numClass,) * 2)  # empty confusion matrix
21
+
22
+ def pixelAccuracy(self):
23
+ # return all-class overall pixel accuracy: correctly classified pixels / total pixels
24
+ # PA = acc = (TP + TN) / (TP + TN + FP + FN)
25
+ acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
26
+ return acc
27
+
28
+ def classPixelAccuracy(self):
29
+ # return each category pixel accuracy(A more accurate way to call it precision)
30
+ # acc = TP / (TP + FP)
31
+ classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
32
+ return classAcc  # returns an array, e.g. [0.90, 0.80, 0.96], the per-class accuracy of classes 1, 2 and 3
33
+
34
+ def meanPixelAccuracy(self):
35
+ """
36
+ Mean Pixel Accuracy (MPA): a simple extension of PA that computes the fraction of correctly classified pixels within each class and then averages over all classes.
37
+ :return:
38
+ """
39
+ classAcc = self.classPixelAccuracy()
40
+ meanAcc = np.nanmean(classAcc)  # np.nanmean averages while ignoring NaN entries
41
+ return meanAcc  # a single value, e.g. np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 ≈ 0.89
42
+
43
+ def IntersectionOverUnion(self):
44
+ # Intersection = TP Union = TP + FP + FN
45
+ # IoU = TP / (TP + FP + FN)
46
+ intersection = np.diag(self.confusionMatrix)  # diagonal elements, one value per class
47
+ union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
48
+ self.confusionMatrix)  # axis=1 sums the rows, axis=0 sums the columns
49
+ IoU = intersection / union  # array of per-class IoU values
50
+ return IoU
51
+
52
+ def meanIntersectionOverUnion(self):
53
+ mIoU = np.nanmean(self.IntersectionOverUnion())  # average of the per-class IoU values
54
+ return mIoU
55
+
56
+ def genConfusionMatrix(self, imgPredict, imgLabel): #
57
+ """
58
+ Same as the fast_hist() function in FCN's score.py: computes the confusion matrix
59
+ :param imgPredict:
60
+ :param imgLabel:
61
+ :return: the confusion matrix
62
+ """
63
+ # remove classes from unlabeled pixels in gt image and predict
64
+ mask = (imgLabel >= 0) & (imgLabel < self.numClass)
65
+ label = self.numClass * imgLabel[mask] + imgPredict[mask]
66
+ count = np.bincount(label, minlength=self.numClass ** 2)
67
+ confusionMatrix = count.reshape(self.numClass, self.numClass)
68
+ # print(confusionMatrix)
69
+ return confusionMatrix
70
+
71
+ def Frequency_Weighted_Intersection_over_Union(self):
72
+ """
73
+ Frequency Weighted IoU (FWIoU): an extension of mIoU that weights each class's IoU by how frequently that class occurs.
74
+ FWIoU = [(TP+FN)/(TP+FP+TN+FN)] * [TP / (TP + FP + FN)]
75
+ """
76
+ freq = np.sum(self.confusionMatrix, axis=1) / np.sum(self.confusionMatrix)
77
+ iu = np.diag(self.confusionMatrix) / (
78
+ np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) -
79
+ np.diag(self.confusionMatrix))
80
+ FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
81
+ return FWIoU
82
+
83
+ def addBatch(self, imgPredict, imgLabel):
84
+ assert imgPredict.shape == imgLabel.shape
85
+ self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)  # accumulate the confusion matrix
86
+ return self.confusionMatrix
87
+
88
+ def reset(self):
89
+ self.confusionMatrix = np.zeros((self.numClass, self.numClass))
90
+
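A usage sketch (illustrative, random data): accumulate a confusion matrix over integer class maps and read off the aggregate scores.

import numpy as np
from util.metrics import SegmentationMetric

metric = SegmentationMetric(numClass=6)
pred = np.random.randint(0, 6, size=(16, 16))
gt = np.random.randint(0, 6, size=(16, 16))
metric.addBatch(pred, gt)
print(metric.pixelAccuracy(), metric.meanIntersectionOverUnion())
metric.reset()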
util/misc.py ADDED
@@ -0,0 +1,340 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ from collections import defaultdict, deque
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from torch import inf
22
+
23
+
24
+ class SmoothedValue(object):
25
+ """Track a series of values and provide access to smoothed values over a
26
+ window or the global series average.
27
+ """
28
+
29
+ def __init__(self, window_size=20, fmt=None):
30
+ if fmt is None:
31
+ fmt = "{median:.4f} ({global_avg:.4f})"
32
+ self.deque = deque(maxlen=window_size)
33
+ self.total = 0.0
34
+ self.count = 0
35
+ self.fmt = fmt
36
+
37
+ def update(self, value, n=1):
38
+ self.deque.append(value)
39
+ self.count += n
40
+ self.total += value * n
41
+
42
+ def synchronize_between_processes(self):
43
+ """
44
+ Warning: does not synchronize the deque!
45
+ """
46
+ if not is_dist_avail_and_initialized():
47
+ return
48
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49
+ dist.barrier()
50
+ dist.all_reduce(t)
51
+ t = t.tolist()
52
+ self.count = int(t[0])
53
+ self.total = t[1]
54
+
55
+ @property
56
+ def median(self):
57
+ d = torch.tensor(list(self.deque))
58
+ return d.median().item()
59
+
60
+ @property
61
+ def avg(self):
62
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
63
+ return d.mean().item()
64
+
65
+ @property
66
+ def global_avg(self):
67
+ return self.total / self.count
68
+
69
+ @property
70
+ def max(self):
71
+ return max(self.deque)
72
+
73
+ @property
74
+ def value(self):
75
+ return self.deque[-1]
76
+
77
+ def __str__(self):
78
+ return self.fmt.format(
79
+ median=self.median,
80
+ avg=self.avg,
81
+ global_avg=self.global_avg,
82
+ max=self.max,
83
+ value=self.value)
84
+
85
+
86
+ class MetricLogger(object):
87
+ def __init__(self, delimiter="\t"):
88
+ self.meters = defaultdict(SmoothedValue)
89
+ self.delimiter = delimiter
90
+
91
+ def update(self, **kwargs):
92
+ for k, v in kwargs.items():
93
+ if v is None:
94
+ continue
95
+ if isinstance(v, torch.Tensor):
96
+ v = v.item()
97
+ assert isinstance(v, (float, int))
98
+ self.meters[k].update(v)
99
+
100
+ def __getattr__(self, attr):
101
+ if attr in self.meters:
102
+ return self.meters[attr]
103
+ if attr in self.__dict__:
104
+ return self.__dict__[attr]
105
+ raise AttributeError("'{}' object has no attribute '{}'".format(
106
+ type(self).__name__, attr))
107
+
108
+ def __str__(self):
109
+ loss_str = []
110
+ for name, meter in self.meters.items():
111
+ loss_str.append(
112
+ "{}: {}".format(name, str(meter))
113
+ )
114
+ return self.delimiter.join(loss_str)
115
+
116
+ def synchronize_between_processes(self):
117
+ for meter in self.meters.values():
118
+ meter.synchronize_between_processes()
119
+
120
+ def add_meter(self, name, meter):
121
+ self.meters[name] = meter
122
+
123
+ def log_every(self, iterable, print_freq, header=None):
124
+ i = 0
125
+ if not header:
126
+ header = ''
127
+ start_time = time.time()
128
+ end = time.time()
129
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
130
+ data_time = SmoothedValue(fmt='{avg:.4f}')
131
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
132
+ log_msg = [
133
+ header,
134
+ '[{0' + space_fmt + '}/{1}]',
135
+ 'eta: {eta}',
136
+ '{meters}',
137
+ 'time: {time}',
138
+ 'data: {data}'
139
+ ]
140
+ if torch.cuda.is_available():
141
+ log_msg.append('max mem: {memory:.0f}')
142
+ log_msg = self.delimiter.join(log_msg)
143
+ MB = 1024.0 * 1024.0
144
+ for obj in iterable:
145
+ data_time.update(time.time() - end)
146
+ yield obj
147
+ iter_time.update(time.time() - end)
148
+ if i % print_freq == 0 or i == len(iterable) - 1:
149
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
150
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
151
+ if torch.cuda.is_available():
152
+ print(log_msg.format(
153
+ i, len(iterable), eta=eta_string,
154
+ meters=str(self),
155
+ time=str(iter_time), data=str(data_time),
156
+ memory=torch.cuda.max_memory_allocated() / MB))
157
+ else:
158
+ print(log_msg.format(
159
+ i, len(iterable), eta=eta_string,
160
+ meters=str(self),
161
+ time=str(iter_time), data=str(data_time)))
162
+ i += 1
163
+ end = time.time()
164
+ total_time = time.time() - start_time
165
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
166
+ print('{} Total time: {} ({:.4f} s / it)'.format(
167
+ header, total_time_str, total_time / len(iterable)))
168
+
169
+
170
+ def setup_for_distributed(is_master):
171
+ """
172
+ This function disables printing when not in master process
173
+ """
174
+ builtin_print = builtins.print
175
+
176
+ def print(*args, **kwargs):
177
+ force = kwargs.pop('force', False)
178
+ force = force or (get_world_size() > 8)
179
+ if is_master or force:
180
+ now = datetime.datetime.now().time()
181
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
182
+ builtin_print(*args, **kwargs)
183
+
184
+ builtins.print = print
185
+
186
+
187
+ def is_dist_avail_and_initialized():
188
+ if not dist.is_available():
189
+ return False
190
+ if not dist.is_initialized():
191
+ return False
192
+ return True
193
+
194
+
195
+ def get_world_size():
196
+ if not is_dist_avail_and_initialized():
197
+ return 1
198
+ return dist.get_world_size()
199
+
200
+
201
+ def get_rank():
202
+ if not is_dist_avail_and_initialized():
203
+ return 0
204
+ return dist.get_rank()
205
+
206
+
207
+ def is_main_process():
208
+ return get_rank() == 0
209
+
210
+
211
+ def save_on_master(*args, **kwargs):
212
+ if is_main_process():
213
+ torch.save(*args, **kwargs)
214
+
215
+
216
+ def init_distributed_mode(args):
217
+ if args.dist_on_itp:
218
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
219
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
220
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
221
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
222
+ os.environ['LOCAL_RANK'] = str(args.gpu)
223
+ os.environ['RANK'] = str(args.rank)
224
+ os.environ['WORLD_SIZE'] = str(args.world_size)
225
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
226
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
227
+ args.rank = int(os.environ["RANK"])
228
+ args.world_size = int(os.environ['WORLD_SIZE'])
229
+ args.gpu = int(os.environ['LOCAL_RANK'])
230
+ elif 'SLURM_PROCID' in os.environ:
231
+ args.rank = int(os.environ['SLURM_PROCID'])
232
+ args.gpu = args.rank % torch.cuda.device_count()
233
+ else:
234
+ print('Not using distributed mode')
235
+ setup_for_distributed(is_master=True) # hack
236
+ args.distributed = False
237
+ return
238
+
239
+ args.distributed = True
240
+
241
+ torch.cuda.set_device(args.gpu)
242
+ args.dist_backend = 'nccl'
243
+ print('| distributed init (rank {}): {}, gpu {}'.format(
244
+ args.rank, args.dist_url, args.gpu), flush=True)
245
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
246
+ world_size=args.world_size, rank=args.rank)
247
+ torch.distributed.barrier()
248
+ setup_for_distributed(args.rank == 0)
249
+
250
+
251
+ class NativeScalerWithGradNormCount:
252
+ state_dict_key = "amp_scaler"
253
+
254
+ def __init__(self):
255
+ self._scaler = torch.cuda.amp.GradScaler()
256
+
257
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
258
+ self._scaler.scale(loss).backward(create_graph=create_graph)
259
+ if update_grad:
260
+ if clip_grad is not None:
261
+ assert parameters is not None
262
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
263
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
264
+ else:
265
+ self._scaler.unscale_(optimizer)
266
+ norm = get_grad_norm_(parameters)
267
+ self._scaler.step(optimizer)
268
+ self._scaler.update()
269
+ else:
270
+ norm = None
271
+ return norm
272
+
273
+ def state_dict(self):
274
+ return self._scaler.state_dict()
275
+
276
+ def load_state_dict(self, state_dict):
277
+ self._scaler.load_state_dict(state_dict)
278
+
279
+
280
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
281
+ if isinstance(parameters, torch.Tensor):
282
+ parameters = [parameters]
283
+ parameters = [p for p in parameters if p.grad is not None]
284
+ norm_type = float(norm_type)
285
+ if len(parameters) == 0:
286
+ return torch.tensor(0.)
287
+ device = parameters[0].grad.device
288
+ if norm_type == inf:
289
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
290
+ else:
291
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
292
+ return total_norm
293
+
294
+
295
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
296
+ output_dir = Path(args.output_dir)
297
+ epoch_name = str(epoch)
298
+ if loss_scaler is not None:
299
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
300
+ for checkpoint_path in checkpoint_paths:
301
+ to_save = {
302
+ 'model': model_without_ddp.state_dict(),
303
+ 'optimizer': optimizer.state_dict(),
304
+ 'epoch': epoch,
305
+ 'scaler': loss_scaler.state_dict(),
306
+ 'args': args,
307
+ }
308
+
309
+ save_on_master(to_save, checkpoint_path)
310
+ else:
311
+ client_state = {'epoch': epoch}
312
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
313
+
314
+
315
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
316
+ if args.resume:
317
+ if args.resume.startswith('https'):
318
+ checkpoint = torch.hub.load_state_dict_from_url(
319
+ args.resume, map_location='cpu', check_hash=True)
320
+ else:
321
+ checkpoint = torch.load(args.resume, map_location='cpu')
322
+ model_without_ddp.load_state_dict(checkpoint['model'])
323
+ print("Resume checkpoint %s" % args.resume)
324
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
325
+ optimizer.load_state_dict(checkpoint['optimizer'])
326
+ args.start_epoch = checkpoint['epoch'] + 1
327
+ if 'scaler' in checkpoint:
328
+ loss_scaler.load_state_dict(checkpoint['scaler'])
329
+ print("With optim & sched!")
330
+
331
+
332
+ def all_reduce_mean(x):
333
+ world_size = get_world_size()
334
+ if world_size > 1:
335
+ x_reduce = torch.tensor(x).cuda()
336
+ dist.all_reduce(x_reduce)
337
+ x_reduce /= world_size
338
+ return x_reduce.item()
339
+ else:
340
+ return x
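A condensed training-loop sketch showing how MetricLogger and the gradient-norm-tracking loss scaler are typically combined (illustrative only; it assumes a CUDA device because the scaler uses torch.cuda.amp):

import torch
from util.misc import MetricLogger, NativeScalerWithGradNormCount

model = torch.nn.Linear(8, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScalerWithGradNormCount()
logger = MetricLogger(delimiter="  ")

data = [(torch.randn(4, 8), torch.randn(4, 2)) for _ in range(3)]
for x, y in logger.log_every(data, print_freq=1, header='Epoch: [0]'):
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x.cuda()), y.cuda())
    grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                            parameters=model.parameters(), update_grad=True)
    optimizer.zero_grad()
    logger.update(loss=loss, grad_norm=grad_norm)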
util/msssim.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import exp
4
+
5
+
6
+ def gaussian(window_size, sigma):
7
+ gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
8
+ return gauss/gauss.sum()
9
+
10
+
11
+ def create_window(window_size, channel=1):
12
+ _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14
+ window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
15
+ return window
16
+
17
+
18
+ def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
19
+ # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
20
+ if val_range is None:
21
+ if torch.max(img1) > 128:
22
+ max_val = 255
23
+ else:
24
+ max_val = 1
25
+
26
+ if torch.min(img1) < -0.5:
27
+ min_val = -1
28
+ else:
29
+ min_val = 0
30
+ L = max_val - min_val
31
+ else:
32
+ L = val_range
33
+
34
+ padd = 0
35
+ (_, channel, height, width) = img1.size()
36
+ if window is None:
37
+ real_size = min(window_size, height, width)
38
+ window = create_window(real_size, channel=channel).to(img1.device)
39
+
40
+ mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
41
+ mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
42
+
43
+ mu1_sq = mu1.pow(2)
44
+ mu2_sq = mu2.pow(2)
45
+ mu1_mu2 = mu1 * mu2
46
+
47
+ sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
48
+ sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
49
+ sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
50
+
51
+ C1 = (0.01 * L) ** 2
52
+ C2 = (0.03 * L) ** 2
53
+
54
+ v1 = 2.0 * sigma12 + C2
55
+ v2 = sigma1_sq + sigma2_sq + C2
56
+ cs = torch.mean(v1 / v2) # contrast sensitivity
57
+
58
+ ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
59
+
60
+ if size_average:
61
+ ret = ssim_map.mean()
62
+ else:
63
+ ret = ssim_map.mean(1).mean(1).mean(1)
64
+
65
+ if full:
66
+ return ret, cs
67
+ return ret
68
+
69
+
70
+ def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True):
71
+ device = img1.device
72
+ weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
73
+ levels = weights.size()[0]
74
+ mssim = []
75
+ mcs = []
76
+ for _ in range(levels):
77
+ sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
78
+ mssim.append(sim)
79
+ mcs.append(cs)
80
+
81
+ img1 = F.avg_pool2d(img1, (2, 2))
82
+ img2 = F.avg_pool2d(img2, (2, 2))
83
+
84
+ mssim = torch.stack(mssim)
85
+ mcs = torch.stack(mcs)
86
+
87
+ # Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
88
+ if normalize:
89
+ mssim = (mssim + 1) / 2
90
+ mcs = (mcs + 1) / 2
91
+
92
+ pow1 = mcs ** weights
93
+ pow2 = mssim ** weights
94
+ # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
95
+ output = torch.prod(pow1[:-1]) * pow2[-1]  # cs terms of the first scales times the ssim term of the last scale
96
+ return output
97
+
98
+
99
+ # Classes to re-use window
100
+ class SSIM(torch.nn.Module):
101
+ def __init__(self, window_size=11, size_average=True, val_range=None):
102
+ super(SSIM, self).__init__()
103
+ self.window_size = window_size
104
+ self.size_average = size_average
105
+ self.val_range = val_range
106
+
107
+ # Assume 1 channel for SSIM
108
+ self.channel = 1
109
+ self.window = create_window(window_size)
110
+
111
+ def forward(self, img1, img2):
112
+ (_, channel, _, _) = img1.size()
113
+
114
+ if channel == self.channel and self.window.dtype == img1.dtype:
115
+ window = self.window
116
+ else:
117
+ window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
118
+ self.window = window
119
+ self.channel = channel
120
+
121
+ return 1 - ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
122
+
123
+ class MSSSIM(torch.nn.Module):
124
+ def __init__(self, window_size=11, size_average=True, channel=1):
125
+ super(MSSSIM, self).__init__()
126
+ self.window_size = window_size
127
+ self.size_average = size_average
128
+ self.channel = channel
129
+
130
+ def forward(self, img1, img2):
131
+ # TODO: store window between calls if possible
132
+ return 1 - msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
133
+
134
+ class PSNR(torch.nn.Module):
135
+ def __init__(self):
136
+ super(PSNR, self).__init__()
137
+
138
+ def torchPSNR(self,tar_img, prd_img):
139
+ imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
140
+ rmse = (imdff**2).mean().sqrt()
141
+ ps = 20*torch.log10(1/rmse)
142
+ return ps
143
+
144
+ def forward(self, img1, img2):
145
+ # TODO: store window between calls if possible
146
+ return self.torchPSNR(img1, img2)
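A brief usage sketch (not part of the commit): SSIM and MSSSIM return 1 - similarity, so they can be minimised directly as losses, while PSNR returns the metric itself.

import torch
from util.msssim import SSIM, PSNR

img1 = torch.rand(2, 1, 64, 64)
img2 = torch.rand(2, 1, 64, 64)

ssim_loss = SSIM(window_size=11)(img1, img2)   # 1 - SSIM
psnr = PSNR()(img1, img2)
print(ssim_loss.item(), psnr.item())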
util/pos_embed.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """
51
+ embed_dim: output dimension for each position
52
+ pos: a list of positions to be encoded: size (M,)
53
+ out: (M, D)
54
+ """
55
+ assert embed_dim % 2 == 0
56
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24+
57
+ omega /= embed_dim / 2.
58
+ omega = 1. / 10000**omega # (D/2,)
59
+
60
+ pos = pos.reshape(-1) # (M,)
61
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62
+
63
+ emb_sin = np.sin(out) # (M, D/2)
64
+ emb_cos = np.cos(out) # (M, D/2)
65
+
66
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67
+ return emb
68
+
69
+
70
+ # --------------------------------------------------------
71
+ # Interpolate position embeddings for high-resolution
72
+ # References:
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
+ def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None):
76
+ if 'pos_embed' in checkpoint_model:
77
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
78
+ embedding_size = pos_embed_checkpoint.shape[-1]
79
+ num_patches = model.patch_embed.num_patches
80
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81
+ # height (== width) for the checkpoint position embedding
82
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83
+ # height (== width) for the new position embedding
84
+ new_size = int(num_patches ** 0.5)
85
+ # class_token and dist_token are kept unchanged
86
+ if orig_size != new_size:
87
+ if newsize1 is None:
88
+ newsize1,newsize2 = new_size,new_size
89
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2))
90
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
91
+ # only the position tokens are interpolated
92
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
93
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
94
+ pos_tokens = torch.nn.functional.interpolate(
95
+ pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False)
96
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
97
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
98
+ checkpoint_model['pos_embed'] = new_pos_embed
99
+ # elif orig_size > new_size:
100
+ # print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
101
+ # pos_tokens = get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True)
102
+ # pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0)
103
+ # checkpoint_model['pos_embed'] = pos_tokens
104
+
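A quick shape check of the sine-cosine table (illustrative): with a 14x14 patch grid and a class token the embedding has 197 rows.

from util.pos_embed import get_2d_sincos_pos_embed

pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pe.shape)   # (197, 768); row 0 is the all-zero cls_token slot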
util/pos_embedtest.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39
+ assert embed_dim % 2 == 0
40
+
41
+ # use half of dimensions to encode grid_h
42
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44
+
45
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46
+ return emb
47
+
48
+
49
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50
+ """
51
+ embed_dim: output dimension for each position
52
+ pos: a list of positions to be encoded: size (M,)
53
+ out: (M, D)
54
+ """
55
+ assert embed_dim % 2 == 0
56
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24+
57
+ omega /= embed_dim / 2.
58
+ omega = 1. / 10000**omega # (D/2,)
59
+
60
+ pos = pos.reshape(-1) # (M,)
61
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62
+
63
+ emb_sin = np.sin(out) # (M, D/2)
64
+ emb_cos = np.cos(out) # (M, D/2)
65
+
66
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67
+ return emb
68
+
69
+
70
+ # --------------------------------------------------------
71
+ # Interpolate position embeddings for high-resolution
72
+ # References:
73
+ # DeiT: https://github.com/facebookresearch/deit
74
+ # --------------------------------------------------------
75
+ def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None):
76
+ if 'pos_embed' in checkpoint_model:
77
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
78
+ embedding_size = pos_embed_checkpoint.shape[-1]
79
+ num_patches = model.patch_embed.num_patches
80
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81
+ # height (== width) for the checkpoint position embedding
82
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83
+ # height (== width) for the new position embedding
84
+ new_size = int(num_patches ** 0.5)
85
+ # class_token and dist_token are kept unchanged
86
+ if orig_size != new_size:
87
+ if newsize1 is None:
88
+ newsize1,newsize2 = new_size,new_size
89
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2))
90
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
91
+ # only the position tokens are interpolated
92
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
93
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
94
+ pos_tokens = torch.nn.functional.interpolate(
95
+ pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False)
96
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
97
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
98
+ checkpoint_model['pos_embed'] = new_pos_embed
99
+
100
+ def interpolate_dec_embed(model, checkpoint_model):
101
+ if 'decoder_pos_embed' in checkpoint_model:
102
+ pos_embed_checkpoint = checkpoint_model['decoder_pos_embed']
103
+ embedding_size = pos_embed_checkpoint.shape[-1]
104
+ num_patches = model.decoder_pos_embed.num_patches
105
+ num_extra_tokens = model.decoder_pos_embed.shape[-2] - num_patches
106
+ # height (== width) for the checkpoint position embedding
107
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
108
+ # height (== width) for the new position embedding
109
+ new_size = int(num_patches ** 0.5)
110
+ # class_token and dist_token are kept unchanged
111
+ if orig_size != new_size:
112
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
113
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
114
+ # only the position tokens are interpolated
115
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
116
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
117
+ pos_tokens = torch.nn.functional.interpolate(
118
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
119
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
120
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
121
+ checkpoint_model['decoder_pos_embed'] = new_pos_embed
122
+ # elif orig_size > new_size:
123
+ # print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
124
+ # pos_tokens = get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True)
125
+ # pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0)
126
+ # checkpoint_model['pos_embed'] = pos_tokens
127
+
util/post_processing.py ADDED
@@ -0,0 +1,305 @@
1
+ import torch
2
+ import numpy as np
3
+ import PIL.Image as Image
4
+ import torchvision.transforms as transforms
5
+ import torch.nn.functional as F
6
+ from typing import Optional, Tuple, Union
7
+
8
+
9
+ def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
10
+ """
11
+ Perform morphological opening on a 2D torch tensor (image).
12
+
13
+ Args:
14
+ image (torch.Tensor): image to open
15
+ kernel_size (int): size of the structuring element - roughly the size of hole to be opened
16
+
17
+ Returns:
18
+ torch.Tensor: The opened image.
19
+ """
20
+ kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
21
+ eroded = F.conv2d(image.unsqueeze(0), kernel, stride=1, padding=kernel_size // 2)
22
+ eroded = (eroded >= kernel_size * kernel_size).float()  # erosion: keep only pixels whose whole neighborhood is 1
23
+ dilated = F.conv2d(eroded, kernel, stride=1, padding=kernel_size // 2)
24
+ return (dilated > 0).float()
25
+
26
+ def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
27
+ """
28
+ Perform morphological closing on a 2D torch tensor (image).
29
+
30
+ Args:
31
+ image (torch.Tensor): image to close
32
+ kernel_size (int): size of the structuring element - roughly the size of hole to be closed
33
+
34
+ Returns:
35
+ torch.Tensor: The closed image.
36
+ """
37
+ kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
38
+ dilated = F.conv2d(image.unsqueeze(0), kernel, stride=1, padding=kernel_size // 2)
39
+ dilated = (dilated > 0).float()
40
+ eroded = F.conv2d(dilated, kernel, stride=1, padding=kernel_size // 2)
41
+ return (eroded >= kernel_size * kernel_size).float()  # erosion after the dilation completes the closing
42
+
43
+ def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
44
+ """
45
+ Gaussian Convolution to smooth image
46
+
47
+ Args:
48
+ image (torch.Tensor): image to convolve
49
+ kernel_size (int): size of the Gaussian kernel
50
+ sigma (float): standard deviation of the Gaussian distribution
51
+
52
+ Returns:
53
+ torch.Tensor: The convolved image.
54
+ """
55
+ x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
56
+ y = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
57
+ x, y = torch.meshgrid(x, y)
58
+ kernel = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
59
+ kernel = kernel / kernel.sum()
60
+ # Apply the Gaussian kernel
61
+ return F.conv2d(image.unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0), stride=1, padding=kernel_size // 2)
62
+
63
+ def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
64
+ """
65
+ Hysteresis Filter Function - for Canny Edge detection
66
+
67
+ Args:
68
+ image (torch.Tensor): image to process
69
+ low_threshold (float): low threshold for hysteresis
70
+ high_threshold (float): high threshold for hysteresis
71
+
72
+ Returns:
73
+ edge (torch.Tensor): The edges detected in the image.
74
+
75
+ """
76
+ edges = (image > high_threshold).float()
77
+ # Perform hysteresis thresholding
78
+ edges = torch.where(image > low_threshold, edges, 0)
79
+ return edges
80
+
81
+ def non_maxima_suppression_2d(
82
+ image: torch.Tensor,
83
+ kernel_size: int = 3,
84
+ threshold: Optional[float] = None,
85
+ return_mask: bool = False
86
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
87
+ """
88
+ Perform non-maxima suppression on a 2D torch tensor (image).
89
+
90
+ Args:
91
+ image (torch.Tensor): Input tensor of shape (H, W) or (B, C, H, W) or (C, H, W)
92
+ kernel_size (int): Size of the local neighborhood for maxima detection (default: 3)
93
+ threshold (float, optional): Minimum value threshold for considering pixels
94
+ return_mask (bool): If True, return both suppressed image and binary mask
95
+
96
+ Returns:
97
+ torch.Tensor: Image with non-maxima suppressed
98
+ torch.Tensor (optional): Binary mask of local maxima if return_mask=True
99
+ """
100
+ original_shape = image.shape
101
+
102
+ # Handle different input shapes
103
+ if len(image.shape) == 2: # (H, W)
104
+ image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
105
+ elif len(image.shape) == 3: # (C, H, W)
106
+ image = image.unsqueeze(0) # (1, C, H, W)
107
+ elif len(image.shape) == 4: # (B, C, H, W)
108
+ pass
109
+ else:
110
+ raise ValueError(f"Unsupported tensor shape: {original_shape}")
111
+
112
+ batch_size, channels, height, width = image.shape
113
+
114
+ # Apply threshold if specified
115
+ if threshold is not None:
116
+ image = torch.where(image >= threshold, image, torch.tensor(0.0, device=image.device))
117
+
118
+ # Perform max pooling to find local maxima
119
+ padding = kernel_size // 2
120
+ max_pooled = F.max_pool2d(image, kernel_size=kernel_size, stride=1, padding=padding)
121
+
122
+ # Create mask where original values equal max pooled values (local maxima)
123
+ mask = (image == max_pooled) & (image > 0)
124
+
125
+ # Apply non-maxima suppression
126
+ suppressed = image * mask.float()
127
+
128
+ # Reshape back to original shape
129
+ if len(original_shape) == 2:
130
+ suppressed = suppressed.squeeze(0).squeeze(0)
131
+ mask = mask.squeeze(0).squeeze(0)
132
+ elif len(original_shape) == 3:
133
+ suppressed = suppressed.squeeze(0)
134
+ mask = mask.squeeze(0)
135
+
136
+ if return_mask:
137
+ return suppressed, mask
138
+ return suppressed
139
+
140
+
141
+ def non_maxima_suppression_with_orientation(
142
+ magnitude: torch.Tensor,
143
+ orientation: torch.Tensor,
144
+ threshold: Optional[float] = None
145
+ ) -> torch.Tensor:
146
+ """
147
+ Perform oriented non-maxima suppression (commonly used in edge detection).
148
+
149
+ Args:
150
+ magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W) or (B, C, H, W)
151
+ orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
152
+ threshold (float, optional): Minimum magnitude threshold
153
+
154
+ Returns:
155
+ torch.Tensor: Non-maxima suppressed magnitude
156
+ """
157
+ original_shape = magnitude.shape
158
+
159
+ # Handle different input shapes
160
+ if len(magnitude.shape) == 2:
161
+ magnitude = magnitude.unsqueeze(0).unsqueeze(0)
162
+ orientation = orientation.unsqueeze(0).unsqueeze(0)
163
+ elif len(magnitude.shape) == 3:
164
+ magnitude = magnitude.unsqueeze(0)
165
+ orientation = orientation.unsqueeze(0)
166
+
167
+ batch_size, channels, height, width = magnitude.shape
168
+ device = magnitude.device
169
+
170
+ # Apply threshold if specified
171
+ if threshold is not None:
172
+ magnitude = torch.where(magnitude >= threshold, magnitude, torch.tensor(0.0, device=device))
173
+
174
+ # Convert orientation to degrees and normalize to [0, 180)
175
+ angle = torch.rad2deg(orientation) % 180
176
+
177
+ # Create padded magnitude for neighbor comparison
178
+ mag_padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)
179
+
180
+ # Initialize output
181
+ suppressed = torch.zeros_like(magnitude)
182
+
183
+ # Define 8-connectivity neighbors
184
+ for b in range(batch_size):
185
+ for c in range(channels):
186
+ mag = magnitude[b, c]
187
+ ang = angle[b, c]
188
+ mag_pad = mag_padded[b, c]
189
+
190
+ for i in range(1, height + 1):
191
+ for j in range(1, width + 1):
192
+ current_mag = mag_pad[i, j]
193
+ current_angle = ang[i-1, j-1]
194
+
195
+ if current_mag == 0:
196
+ continue
197
+
198
+ # Determine interpolation direction based on angle
199
+ if (0 <= current_angle < 22.5) or (157.5 <= current_angle < 180):
200
+ # Horizontal direction (0°)
201
+ neighbor1 = mag_pad[i, j-1]
202
+ neighbor2 = mag_pad[i, j+1]
203
+ elif 22.5 <= current_angle < 67.5:
204
+ # Diagonal direction (45°)
205
+ neighbor1 = mag_pad[i-1, j+1]
206
+ neighbor2 = mag_pad[i+1, j-1]
207
+ elif 67.5 <= current_angle < 112.5:
208
+ # Vertical direction (90°)
209
+ neighbor1 = mag_pad[i-1, j]
210
+ neighbor2 = mag_pad[i+1, j]
211
+ else: # 112.5 <= current_angle < 157.5
212
+ # Diagonal direction (135°)
213
+ neighbor1 = mag_pad[i-1, j-1]
214
+ neighbor2 = mag_pad[i+1, j+1]
215
+
216
+ # Keep pixel if it's a local maximum
217
+ if current_mag >= neighbor1 and current_mag >= neighbor2:
218
+ suppressed[b, c, i-1, j-1] = current_mag
219
+
220
+ # Reshape back to original shape
221
+ if len(original_shape) == 2:
222
+ suppressed = suppressed.squeeze(0).squeeze(0)
223
+ elif len(original_shape) == 3:
224
+ suppressed = suppressed.squeeze(0)
225
+
226
+ return suppressed
227
+
228
+
229
+ def adaptive_non_maxima_suppression(
230
+ image: torch.Tensor,
231
+ num_points: int,
232
+ min_distance: int = 5,
233
+ threshold: Optional[float] = None
234
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
235
+ """
236
+ Adaptive non-maxima suppression that selects a fixed number of strongest points
237
+ while maintaining minimum distance between them.
238
+
239
+ Args:
240
+ image (torch.Tensor): Input tensor of shape (H, W)
241
+ num_points (int): Number of points to select
242
+ min_distance (int): Minimum distance between selected points
243
+ threshold (float, optional): Minimum value threshold
244
+
245
+ Returns:
246
+ Tuple[torch.Tensor, torch.Tensor]: Coordinates (y, x) and values of selected points
247
+ """
248
+ if len(image.shape) != 2:
249
+ raise ValueError("Input must be a 2D tensor")
250
+
251
+ height, width = image.shape
252
+ device = image.device
253
+
254
+ # Apply threshold if specified
255
+ if threshold is not None:
256
+ image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))
257
+
258
+ # Find all local maxima using simple NMS
259
+ nms_result = non_maxima_suppression_2d(image, kernel_size=3)
260
+
261
+ # Get coordinates and values of all local maxima
262
+ y_coords, x_coords = torch.nonzero(nms_result > 0, as_tuple=True)
263
+ values = nms_result[y_coords, x_coords]
264
+
265
+ if len(values) == 0:
266
+ return torch.empty((0, 2), device=device), torch.empty(0, device=device)
267
+
268
+ # Sort by strength (descending)
269
+ sorted_indices = torch.argsort(values, descending=True)
270
+ y_coords = y_coords[sorted_indices]
271
+ x_coords = x_coords[sorted_indices]
272
+ values = values[sorted_indices]
273
+
274
+ # Select points with minimum distance constraint
275
+ selected_coords = []
276
+ selected_values = []
277
+
278
+ for i in range(len(values)):
279
+ if len(selected_coords) >= num_points:
280
+ break
281
+
282
+ current_y, current_x = y_coords[i].item(), x_coords[i].item()
283
+ current_val = values[i].item()
284
+
285
+ # Check distance to all previously selected points
286
+ valid = True
287
+ for sel_y, sel_x in selected_coords:
288
+ distance = ((current_y - sel_y) ** 2 + (current_x - sel_x) ** 2) ** 0.5
289
+ if distance < min_distance:
290
+ valid = False
291
+ break
292
+
293
+ if valid:
294
+ selected_coords.append((current_y, current_x))
295
+ selected_values.append(current_val)
296
+
297
+ if selected_coords:
298
+ coords_tensor = torch.tensor(selected_coords, device=device, dtype=torch.float32)
299
+ values_tensor = torch.tensor(selected_values, device=device, dtype=torch.float32)
300
+ else:
301
+ coords_tensor = torch.empty((0, 2), device=device)
302
+ values_tensor = torch.empty(0, device=device)
303
+
304
+ return coords_tensor, values_tensor
305
+
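A usage sketch (illustrative, random input standing in for a fault-probability map): plain NMS keeps every local maximum, while the adaptive variant returns a fixed number of peaks separated by a minimum distance.

import torch
from util.post_processing import non_maxima_suppression_2d, adaptive_non_maxima_suppression

prob = torch.rand(128, 128)
peaks, mask = non_maxima_suppression_2d(prob, kernel_size=5, threshold=0.5, return_mask=True)
coords, values = adaptive_non_maxima_suppression(prob, num_points=50, min_distance=10, threshold=0.5)
print(int(mask.sum()), coords.shape)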
util/size_aware_batching.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Size-aware batching utilities for variable-sized seismic images
3
+ """
4
+
5
+ import torch
6
+ from torch.utils.data import DataLoader, Sampler
7
+ import numpy as np
8
+ from collections import defaultdict
9
+ import random
10
+
11
+
12
+ class SizeAwareSampler(Sampler):
13
+ """
14
+ Groups samples by size and creates batches with images of the same size
15
+ """
16
+ def __init__(self, dataset, batch_size, get_size_fn=None):
17
+ """
18
+ Args:
19
+ dataset: PyTorch dataset
20
+ batch_size: batch size for each size group
21
+ get_size_fn: function that takes dataset index and returns (height, width)
22
+ If None, will try to infer from dataset
23
+ """
24
+ self.dataset = dataset
25
+ self.batch_size = batch_size
26
+ self.get_size_fn = get_size_fn
27
+
28
+ # Group indices by size
29
+ self.size_groups = self._group_by_size()
30
+
31
+ # Create batches
32
+ self.batches = self._create_batches()
33
+
34
+ def _group_by_size(self):
35
+ """Group dataset indices by image size"""
36
+ size_groups = defaultdict(list)
37
+
38
+ for idx in range(len(self.dataset)):
39
+ if self.get_size_fn:
40
+ size = self.get_size_fn(idx)
41
+ else:
42
+ # Try to get size from dataset item
43
+ sample = self.dataset[idx]
44
+ if isinstance(sample, (tuple, list)):
45
+ # Assume first element is the image tensor
46
+ img_tensor = sample[0]
47
+ else:
48
+ img_tensor = sample
49
+
50
+ # Get size from tensor shape (assuming shape is [C, H, W] or [H, W])
51
+ if len(img_tensor.shape) == 3:
52
+ size = (img_tensor.shape[1], img_tensor.shape[2]) # H, W
53
+ elif len(img_tensor.shape) == 2:
54
+ size = (img_tensor.shape[0], img_tensor.shape[1]) # H, W
55
+ else:
56
+ raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
57
+
58
+ size_groups[size].append(idx)
59
+
60
+ return size_groups
61
+ def _create_batches(self, random_size = True):
62
+ """Create batches from size groups"""
63
+ batches = []
64
+
65
+ for size, indices in self.size_groups.items():
66
+ # Shuffle indices within each size group
67
+ random.shuffle(indices)
68
+
69
+ # Create batches of the specified size
70
+ for i in range(0, len(indices), self.batch_size):
71
+ batch = indices[i:i + self.batch_size]
72
+ batches.append(batch)
73
+
74
+ return batches
75
+
76
+ def __iter__(self):
77
+ # Shuffle the order of batches
78
+ random.shuffle(self.batches)
79
+ for batch in self.batches:
80
+ yield batch
81
+
82
+ def __len__(self):
83
+ return len(self.batches)
84
+
85
+
86
+ class FixedSizeSampler(Sampler):
87
+ """
88
+ Sampler for datasets where you know the exact 3 size categories
89
+ More efficient than SizeAwareSampler when sizes are known
90
+ """
91
+ def __init__(self, dataset, batch_size, size_categories):
92
+ """
93
+ Args:
94
+ dataset: PyTorch dataset
95
+ batch_size: batch size for each size category
96
+ size_categories: list of (height, width) tuples for the 3 categories
97
+ e.g., [(601, 200), (200, 255), (601, 255)]
98
+ """
99
+ self.dataset = dataset
100
+ self.batch_size = batch_size
101
+ self.size_categories = size_categories
102
+
103
+ # Map indices to size categories
104
+ self.size_to_indices = {size: [] for size in size_categories}
105
+ self._categorize_indices()
106
+
107
+ # Create batches
108
+ self.batches = self._create_batches()
109
+
110
+ def _categorize_indices(self):
111
+ """Categorize dataset indices by their size"""
112
+ for idx in range(len(self.dataset)):
113
+ sample = self.dataset[idx]
114
+ if isinstance(sample, (tuple, list)):
115
+ img_tensor = sample[0]
116
+ else:
117
+ img_tensor = sample
118
+
119
+ # Get size from tensor
120
+ if len(img_tensor.shape) == 3:
121
+ size = (img_tensor.shape[1], img_tensor.shape[2])
122
+ elif len(img_tensor.shape) == 2:
123
+ size = (img_tensor.shape[0], img_tensor.shape[1])
124
+ else:
125
+ raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
126
+
127
+ # Find matching category
128
+ if size in self.size_categories:
129
+ self.size_to_indices[size].append(idx)
130
+ else:
131
+ # Find closest size category (optional)
132
+ closest_size = min(self.size_categories,
133
+ key=lambda cat: abs(cat[0] - size[0]) + abs(cat[1] - size[1]))
134
+ print(f"Warning: Size {size} not in categories, assigning to {closest_size}")
135
+ self.size_to_indices[closest_size].append(idx)
136
+
137
+ def _create_batches(self, random_size = True):
138
+ """Create batches from size categories"""
139
+ batches = []
140
+
141
+ for size, indices in self.size_to_indices.items():
142
+ if not indices:
143
+ continue
144
+
145
+ # Shuffle indices within each size category
146
+ random.shuffle(indices)
147
+
148
+ # Create batches
149
+ for i in range(0, len(indices), self.batch_size):
150
+ batch = indices[i:i + self.batch_size]
151
+ batches.append(batch)
152
+
153
+ return batches
154
+
155
+ def __iter__(self):
156
+ # Shuffle the order of batches across all size categories
157
+ random.shuffle(self.batches)
158
+ for batch in self.batches:
159
+ yield batch
160
+
161
+ def __len__(self):
162
+ return len(self.batches)
163
+
164
+ def get_size_distribution(self):
165
+ """Get the distribution of samples across size categories"""
166
+ distribution = {}
167
+ for size, indices in self.size_to_indices.items():
168
+ distribution[size] = len(indices)
169
+ return distribution
170
+
171
+
172
+ def create_size_aware_dataloader(dataset, batch_size=8, size_categories=None,
173
+ num_workers=4, pin_memory=True, **kwargs):
174
+ """
175
+ Create a DataLoader that batches samples by size
176
+
177
+ Args:
178
+ dataset: PyTorch dataset
179
+ batch_size: batch size for each size group
180
+ size_categories: list of (height, width) tuples for known size categories
181
+ If None, will auto-detect sizes
182
+ num_workers: number of worker processes
183
+ pin_memory: whether to pin memory
184
+ **kwargs: additional arguments for DataLoader
185
+
186
+ Returns:
187
+ DataLoader with size-aware batching
188
+ """
189
+ if size_categories:
190
+ sampler = FixedSizeSampler(dataset, batch_size, size_categories)
191
+ else:
192
+ sampler = SizeAwareSampler(dataset, batch_size)
193
+
194
+ # Remove batch_size from kwargs since we're using a custom sampler
195
+ kwargs.pop('batch_size', None)
196
+ kwargs.pop('shuffle', None) # Sampler handles shuffling
197
+
198
+ return DataLoader(
199
+ dataset,
200
+ batch_sampler=sampler,
201
+ num_workers=num_workers,
202
+ pin_memory=pin_memory,
203
+ **kwargs
204
+ )
205
+
206
+
207
+ # Custom collate function for same-size batches (no padding needed)
208
+ def same_size_collate_fn(batch):
209
+ """
210
+ Collate function for batches where all items have the same size
211
+ No padding required since all images in batch are same size
212
+ """
213
+ if isinstance(batch[0], (tuple, list)):
214
+ # Assuming (image, target) pairs
215
+ images, targets = zip(*batch)
216
+ return torch.stack(images), torch.stack(targets)
217
+ else:
218
+ # Just images
219
+ return torch.stack(batch)
220
+
221
+
222
+
223
+ # Utility function to check batch sizes
224
+ def validate_batch_sizes(dataloader, num_batches_to_check=5):
225
+ """
226
+ Validate that all images in each batch have the same size
227
+ """
228
+ print("Validating batch sizes...")
229
+
230
+ for i, batch in enumerate(dataloader):
231
+ if i >= num_batches_to_check:
232
+ break
233
+
234
+ if isinstance(batch, (tuple, list)):
235
+ images = batch[0]
236
+ else:
237
+ images = batch
238
+
239
+ batch_size = images.shape[0]
240
+ height = images.shape[2]
241
+ width = images.shape[3]
242
+
243
+ print(f"Batch {i}: {batch_size} images of size {height}x{width}")
244
+
245
+ # Verify all images in batch have same size
246
+ for j in range(batch_size):
247
+ img_h, img_w = images[j].shape[1], images[j].shape[2]
248
+ if img_h != height or img_w != width:
249
+ print(f" WARNING: Image {j} has different size {img_h}x{img_w}")
250
+
251
+ print("Validation complete!")
util/skeletonize.py ADDED
@@ -0,0 +1,486 @@
1
+ """
2
+ Courtesy of Martin Menten:
3
+
4
+ Works Cited
5
+ Menten, Martin J., et al. ‘A Skeletonization Algorithm for Gradient-Based Optimization’.
6
+ Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
7
+
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class Skeletonize(torch.nn.Module):
16
+ """
17
+ Class based on PyTorch's Module class to skeletonize two- or three-dimensional input images
18
+ while being fully compatible with PyTorch's autograd automatic differentiation engine as proposed in [1].
19
+
20
+ Attributes:
21
+ probabilistic: a Boolean that indicates whether the input image should be binarized using
22
+ the reparametrization trick and straight-through estimator.
23
+ It should always be set to True if non-binary inputs are being provided.
24
+ beta: scale of added logistic noise during the reparametrization trick. If too small, there will not be any learning via
25
+ gradient-based optimization; if too large, the learning is very slow.
26
+ tau: Boltzmann temperature for reparametrization trick.
27
+ simple_point_detection: decides whether simple points should be identified using Boolean characterization of their 26-neighborhood (Boolean) [2]
28
+ or by checking whether the Euler characteristic changes under their deletion (EulerCharacteristic) [3].
29
+ num_iter: number of iterations that each include one end-point check, eight checks for simple points and eight subsequent deletions.
30
+ The number of iterations should be tuned to the type of input image.
31
+
32
+ [1] Martin J. Menten et al. A skeletonization algorithm for gradient-based optimization.
33
+ Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
34
+ [2] Gilles Bertrand. A boolean characterization of three- dimensional simple points.
35
+ Pattern recognition letters, 17(2):115-124, 1996.
36
+ [3] Steven Lobregt et al. Three-dimensional skeletonization: principle and algorithm.
37
+ IEEE Transactions on pattern analysis and machine intelligence, 2(1):75-77, 1980.
38
+ """
39
+
40
+ def __init__(self, probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean', num_iter=5):
41
+
42
+ super(Skeletonize, self).__init__()
43
+
44
+ self.probabilistic = probabilistic
45
+ self.tau = tau
46
+ self.beta = beta
47
+
48
+ self.num_iter = num_iter
49
+ self.endpoint_check = self._single_neighbor_check
50
+ if simple_point_detection == 'Boolean':
51
+ self.simple_check = self._boolean_simple_check
52
+ elif simple_point_detection == 'EulerCharacteristic':
53
+ self.simple_check = self._euler_characteristic_simple_check
54
+ else:
55
+ raise Exception("simple_point_detection must be either 'Boolean' or 'EulerCharacteristic'.")
56
+
57
+
58
+ def forward(self, img):
59
+
60
+ img = self._prepare_input(img)
61
+
62
+ if self.probabilistic:
63
+ img = self._stochastic_discretization(img)
64
+
65
+ for current_iter in range(self.num_iter):
66
+
67
+ # At each iteration create a new map of the end-points
68
+ is_endpoint = self.endpoint_check(img)
69
+
70
+ # Sub-iterate through eight different subfields
71
+ x_offsets = [0, 1, 0, 1, 0, 1, 0, 1]
72
+ y_offsets = [0, 0, 1, 1, 0, 0, 1, 1]
73
+ z_offsets = [0, 0, 0, 0, 1, 1, 1, 1]
74
+
75
+ for x_offset, y_offset, z_offset in zip(x_offsets, y_offsets, z_offsets):
76
+
77
+ # At each sub-iteration detect all simple points and delete all simple points that are not end-points
78
+ is_simple = self.simple_check(img[:, :, x_offset:, y_offset:, z_offset:])
79
+ deletion_candidates = is_simple * (1 - is_endpoint[:, :, x_offset::2, y_offset::2, z_offset::2])
80
+ img[:, :, x_offset::2, y_offset::2, z_offset::2] = torch.min(img[:, :, x_offset::2, y_offset::2, z_offset::2].clone(), 1 - deletion_candidates)
81
+
82
+ img = self._prepare_output(img)
83
+
84
+ return img
85
+
86
+
87
+
88
+ def _prepare_input(self, img):
89
+ """
90
+ Function to check that the input image is compatible with the subsequent calculations.
91
+ Only two- and three-dimensional images with values between 0 and 1 are supported.
92
+ If the input image is two-dimensional then it is converted into a three-dimensional one for further processing.
93
+ """
94
+
95
+ if img.dim() == 5:
96
+ self.expanded_dims = False
97
+ elif img.dim() == 4:
98
+ self.expanded_dims = True
99
+ img = img.unsqueeze(2)
100
+ else:
101
+ raise Exception("Only two-or three-dimensional images (tensor dimensionality of 4 or 5) are supported as input.")
102
+
103
+ if img.shape[2] == 2 or img.shape[3] == 2 or img.shape[4] == 2 or img.shape[3] == 1 or img.shape[4] == 1:
104
+ raise Exception("Images with a spatial extent of only one or two pixels/voxels are not supported.")
105
+
106
+ if img.min() < 0.0 or img.max() > 1.0:
107
+ raise Exception("Image values must lie between 0 and 1.")
108
+
109
+ img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
110
+
111
+ return img
112
+
113
+
114
+ def _stochastic_discretization(self, img):
115
+ """
116
+ Function to binarize the image so that it can be processed by our skeletonization method.
117
+ In order to remain compatible with backpropagation we utilize the reparameterization trick and a straight-through estimator.
118
+ """
119
+
120
+ alpha = (img + 1e-8) / (1.0 - img + 1e-8)
121
+
122
+ # Draw uniform noise away from 0 and 1 so the logistic transform below stays finite
123
+ uniform_noise = torch.empty_like(img).uniform_(1e-8, 1 - 1e-8)
124
+ logistic_noise = (torch.log(uniform_noise) - torch.log(1 - uniform_noise))
125
+
126
+ img = torch.sigmoid((torch.log(alpha) + logistic_noise * self.beta) / self.tau)
127
+ img = (img.detach() > 0.5).float() - img.detach() + img
128
+
129
+ return img
130
+
131
+
132
+ def _single_neighbor_check(self, img):
133
+ """
134
+ Function that characterizes points as endpoints if they have a single neighbor or no neighbor at all.
135
+ """
136
+
137
+ img = F.pad(img, (1, 1, 1, 1, 1, 1))
138
+
139
+ # Check that number of ones in twentysix-neighborhood is exactly 0 or 1
140
+ K = torch.tensor([[[1.0, 1.0, 1.0],
141
+ [1.0, 1.0, 1.0],
142
+ [1.0, 1.0, 1.0]],
143
+ [[1.0, 1.0, 1.0],
144
+ [1.0, 0.0, 1.0],
145
+ [1.0, 1.0, 1.0]],
146
+ [[1.0, 1.0, 1.0],
147
+ [1.0, 1.0, 1.0],
148
+ [1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)
149
+
150
+ num_twentysix_neighbors = F.conv3d(img, K)
151
+ condition1 = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neighbors
152
+
153
+ return condition1
154
+
155
+
156
+ def _boolean_simple_check(self, img):
157
+ """
158
+ Function that identifies simple points using Boolean conditions introduced by Bertrand [1].
159
+ Each Boolean condition can be assessed via convolutions with a limited number of pre-defined kernels.
160
+ In total, four conditions are checked. If any one is fulfilled, the point is deemed simple.
161
+
162
+ [1] Gilles Bertrand. A boolean characterization of three- dimensional simple points.
163
+ Pattern recognition letters, 17(2):115-124, 1996.
164
+ """
165
+
166
+ img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
167
+
168
+ # Condition 1: number of zeros in the six-neighborhood is exactly 1
169
+ K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
170
+ [0.0, 1.0, 0.0],
171
+ [0.0, 0.0, 0.0]],
172
+ [[0.0, 1.0, 0.0],
173
+ [1.0, 0.0, 1.0],
174
+ [0.0, 1.0, 0.0]],
175
+ [[0.0, 0.0, 0.0],
176
+ [0.0, 1.0, 0.0],
177
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
178
+
179
+ num_six_neighbors = F.conv3d(1 - img, K_N6, stride=2)
180
+
181
+ subcondition1a = F.hardtanh(num_six_neighbors, min_val=0, max_val=1) # 1 or more neighbors
182
+ subcondition1b = F.hardtanh(-(num_six_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neighbors
183
+
184
+ condition1 = subcondition1a * subcondition1b
185
+
186
+
187
+ # Condition 2: number of ones in twentysix-neighborhood is exactly 1
188
+ K_N26 = torch.tensor([[[1.0, 1.0, 1.0],
189
+ [1.0, 1.0, 1.0],
190
+ [1.0, 1.0, 1.0]],
191
+ [[1.0, 1.0, 1.0],
192
+ [1.0, 0.0, 1.0],
193
+ [1.0, 1.0, 1.0]],
194
+ [[1.0, 1.0, 1.0],
195
+ [1.0, 1.0, 1.0],
196
+ [1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)
197
+
198
+ num_twentysix_neighbors = F.conv3d(img, K_N26, stride=2)
199
+
200
+ subcondition2a = F.hardtanh(num_twentysix_neighbors, min_val=0, max_val=1) # 1 or more neighbors
201
+ subcondition2b = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neighbors
202
+
203
+ condition2 = subcondition2a * subcondition2b
204
+
205
+
206
+ # Condition 3: number of ones in the eighteen-neighborhood is exactly 1...
207
+ K_N18 = torch.tensor([[[0.0, 1.0, 0.0],
208
+ [1.0, 1.0, 1.0],
209
+ [0.0, 1.0, 0.0]],
210
+ [[1.0, 1.0, 1.0],
211
+ [1.0, 0.0, 1.0],
212
+ [1.0, 1.0, 1.0]],
213
+ [[0.0, 1.0, 0.0],
214
+ [1.0, 1.0, 1.0],
215
+ [0.0, 1.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
216
+
217
+ num_eighteen_neighbors = F.conv3d(img, K_N18, stride=2)
218
+
219
+ subcondition3a = F.hardtanh(num_eighteen_neighbors, min_val=0, max_val=1) # 1 or more neighbors
220
+ subcondition3b = F.hardtanh(-(num_eighteen_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neighbors
221
+
222
+ # ... and cell configuration B26 does not exist
223
+ K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
224
+ [-1.0, -1.0, 0.0],
225
+ [0.0, 0.0, 0.0]],
226
+ [[-1.0, -1.0, 0.0],
227
+ [-1.0, 0.0, 0.0],
228
+ [0.0, 0.0, 0.0]],
229
+ [[0.0, 0.0, 0.0],
230
+ [0.0, 0.0, 0.0],
231
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
232
+
233
+ B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
234
+ B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
235
+ B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
236
+ B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
237
+ B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
238
+ B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
239
+ B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
240
+ B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
241
+ num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present
242
+
243
+ subcondition3c = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)
244
+
245
+ condition3 = subcondition3a * subcondition3b * subcondition3c
246
+
247
+
248
+ # Condition 4: cell configuration A6 does not exist...
249
+ K_A6 = torch.tensor([[[0.0, 1.0, 0.0],
250
+ [1.0, -1.0, 1.0],
251
+ [0.0, 1.0, 0.0]],
252
+ [[0.0, 0.0, 0.0],
253
+ [0.0, 0.0, 0.0],
254
+ [0.0, 0.0, 0.0]],
255
+ [[0.0, 0.0, 0.0],
256
+ [0.0, 0.0, 0.0],
257
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
258
+
259
+ A6_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A6, stride=2) - 4)
260
+ A6_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 3]), stride=2) - 4)
261
+ A6_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 4]), stride=2) - 4)
262
+ A6_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A6, dims=[2]), stride=2) - 4)
263
+ A6_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 3]), stride=2) - 4)
264
+ A6_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 4]), stride=2) - 4)
265
+ num_A6_cells = A6_1_present + A6_2_present + A6_3_present + A6_4_present + A6_5_present + A6_6_present
266
+
267
+ subcondition4a = F.hardtanh(-(num_A6_cells - 1), min_val=0, max_val=1)
268
+
269
+ # ... and cell configuration B26 does not exist...
270
+ K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
271
+ [-1.0, -1.0, 0.0],
272
+ [0.0, 0.0, 0.0]],
273
+ [[-1.0, -1.0, 0.0],
274
+ [-1.0, 0.0, 0.0],
275
+ [0.0, 0.0, 0.0]],
276
+ [[0.0, 0.0, 0.0],
277
+ [0.0, 0.0, 0.0],
278
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
279
+
280
+ B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
281
+ B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
282
+ B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
283
+ B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
284
+ B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
285
+ B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
286
+ B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
287
+ B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
288
+ num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present
289
+
290
+ subcondition4b = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)
291
+
292
+ # ... and cell configuration B18 does not exist...
293
+ K_B18 = torch.tensor([[[0.0, 1.0, 0.0],
294
+ [-1.0, -1.0, -1.0],
295
+ [0.0, 0.0, 0.0]],
296
+ [[-1.0, -1.0, -1.0],
297
+ [-1.0, 0.0, -1.0],
298
+ [0.0, 0.0, 0.0]],
299
+ [[0.0, 0.0, 0.0],
300
+ [0.0, 0.0, 0.0],
301
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
302
+
303
+ B18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B18, stride=2) - 8)
304
+ B18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4]), stride=2) - 8)
305
+ B18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=2), stride=2) - 8)
306
+ B18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=3), stride=2) - 8)
307
+ B18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4]), stride=2) - 8)
308
+ B18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4]), stride=2) - 8)
309
+ B18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 8)
310
+ B18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 8)
311
+ B18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4], k=2), stride=2) - 8)
312
+ B18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 8)
313
+ B18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 8)
314
+ B18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 8)
315
+ num_B18_cells = B18_1_present + B18_2_present + B18_3_present + B18_4_present + B18_5_present + B18_6_present + B18_7_present + B18_8_present + B18_9_present + B18_10_present + B18_11_present + B18_12_present
316
+
317
+ subcondition4c = F.hardtanh(-(num_B18_cells - 1), min_val=0, max_val=1)
318
+
319
+ # ... and the number of zeros in the six-neighborhood minus the number of A18 cell configurations plus the number of A26 cell configurations is exactly one
320
+ K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
321
+ [0.0, 1.0, 0.0],
322
+ [0.0, 0.0, 0.0]],
323
+ [[0.0, 1.0, 0.0],
324
+ [1.0, 0.0, 1.0],
325
+ [0.0, 1.0, 0.0]],
326
+ [[0.0, 0.0, 0.0],
327
+ [0.0, 1.0, 0.0],
328
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
329
+
330
+ num_six_neighbors = F.conv3d(1-img, K_N6, stride=2)
331
+
332
+ K_A18 = torch.tensor([[[0.0, -1.0, 0.0],
333
+ [0.0, -1.0, 0.0],
334
+ [0.0, 0.0, 0.0]],
335
+ [[0.0, -1.0, 0.0],
336
+ [0.0, 0.0, 0.0],
337
+ [0.0, 0.0, 0.0]],
338
+ [[0.0, 0.0, 0.0],
339
+ [0.0, 0.0, 0.0],
340
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
341
+
342
+ A18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A18, stride=2) - 2)
343
+ A18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4]), stride=2) - 2)
344
+ A18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=2), stride=2) - 2)
345
+ A18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=3), stride=2) - 2)
346
+ A18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4]), stride=2) - 2)
347
+ A18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4]), stride=2) - 2)
348
+ A18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 2)
349
+ A18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 2)
350
+ A18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4], k=2), stride=2) - 2)
351
+ A18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 2)
352
+ A18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 2)
353
+ A18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 2)
354
+ num_A18_cells = A18_1_present + A18_2_present + A18_3_present + A18_4_present + A18_5_present + A18_6_present + A18_7_present + A18_8_present + A18_9_present + A18_10_present + A18_11_present + A18_12_present
355
+
356
+ K_A26 = torch.tensor([[[-1.0, -1.0, 0.0],
357
+ [-1.0, -1.0, 0.0],
358
+ [0.0, 0.0, 0.0]],
359
+ [[-1.0, -1.0, 0.0],
360
+ [-1.0, 0.0, 0.0],
361
+ [0.0, 0.0, 0.0]],
362
+ [[0.0, 0.0, 0.0],
363
+ [0.0, 0.0, 0.0],
364
+ [0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
365
+
366
+ A26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A26, stride=2) - 6)
367
+ A26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2]), stride=2) - 6)
368
+ A26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3]), stride=2) - 6)
369
+ A26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[4]), stride=2) - 6)
370
+ A26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3]), stride=2) - 6)
371
+ A26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 4]), stride=2) - 6)
372
+ A26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3, 4]), stride=2) - 6)
373
+ A26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3, 4]), stride=2) - 6)
374
+ num_A26_cells = A26_1_present + A26_2_present + A26_3_present + A26_4_present + A26_5_present + A26_6_present + A26_7_present + A26_8_present
375
+
376
+ subcondition4d = F.hardtanh(num_six_neighbors - num_A18_cells + num_A26_cells, min_val=0, max_val=1) # 1 or more configurations
377
+ subcondition4e = F.hardtanh(-(num_six_neighbors - num_A18_cells + num_A26_cells - 2), min_val=0, max_val=1) # 1 or fewer configurations
378
+
379
+ condition4 = subcondition4a * subcondition4b * subcondition4c * subcondition4d * subcondition4e
380
+
381
+ # If any of the four conditions is fulfilled the point is simple
382
+ combined = torch.cat([condition1, condition2, condition3, condition4], dim=1)
383
+ is_simple = torch.amax(combined, dim=1, keepdim=True)
384
+
385
+ return is_simple
386
+
387
+
388
+ # Specifically designed to be used with the eight-subfield iterative scheme from above.
389
+ def _euler_characteristic_simple_check(self, img):
390
+ """
391
+ Function that identifies simple points by assessing whether the Euler characteristic changes when the point is deleted [1].
392
+ In order to calculate the Euler characteristic, the numbers of vertices, edges, faces and octants are counted using convolutions with pre-defined kernels.
393
+ The function is meant to be used in combination with the subfield-based iterative scheme employed in the forward function.
394
+
395
+ [1] Steven Lobregt et al. Three-dimensional skeletonization: principle and algorithm.
396
+ IEEE Transactions on pattern analysis and machine intelligence, 2(1):75-77, 1980.
397
+ """
398
+
399
+ img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
400
+
401
+ # Create masked version of the image where the center of 26-neighborhoods is changed to zero
402
+ mask = torch.ones_like(img)
403
+ mask[:, :, 1::2, 1::2, 1::2] = 0
404
+ masked_img = img.clone() * mask
405
+
406
+ # Count vertices
407
+ vertices = F.relu(-(2.0 * img - 1.0))
408
+ num_vertices = F.avg_pool3d(vertices, (3, 3, 3), stride=2) * 27
409
+
410
+ masked_vertices = F.relu(-(2.0 * masked_img - 1.0))
411
+ num_masked_vertices = F.avg_pool3d(masked_vertices, (3, 3, 3), stride=2) * 27
412
+
413
+ # Count edges
414
+ K_ud_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 2, 1, 1)
415
+ K_ns_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 2, 1)
416
+ K_we_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 1, 2)
417
+
418
+ ud_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_edge))
419
+ num_ud_edges = F.avg_pool3d(ud_edges, (2, 3, 3), stride=2) * 18
420
+ ns_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_edge))
421
+ num_ns_edges = F.avg_pool3d(ns_edges, (3, 2, 3), stride=2) * 18
422
+ we_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_edge))
423
+ num_we_edges = F.avg_pool3d(we_edges, (3, 3, 2), stride=2) * 18
424
+ num_edges = num_ud_edges + num_ns_edges + num_we_edges
425
+
426
+ masked_ud_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_edge))
427
+ num_masked_ud_edges = F.avg_pool3d(masked_ud_edges, (2, 3, 3), stride=2) * 18
428
+ masked_ns_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_edge))
429
+ num_masked_ns_edges = F.avg_pool3d(masked_ns_edges, (3, 2, 3), stride=2) * 18
430
+ masked_we_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_edge))
431
+ num_masked_we_edges = F.avg_pool3d(masked_we_edges, (3, 3, 2), stride=2) * 18
432
+ num_masked_edges = num_masked_ud_edges + num_masked_ns_edges + num_masked_we_edges
433
+
434
+ # Count faces
435
+ K_ud_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 1, 2, 2)
436
+ K_ns_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 1, 2)
437
+ K_we_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 2, 1)
438
+
439
+ ud_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_face) - 0.5) * 2
440
+ num_ud_faces = F.avg_pool3d(ud_faces, (3, 2, 2), stride=2) * 12
441
+ ns_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_face) - 0.5) * 2
442
+ num_ns_faces = F.avg_pool3d(ns_faces, (2, 3, 2), stride=2) * 12
443
+ we_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_face) - 0.5) * 2
444
+ num_we_faces = F.avg_pool3d(we_faces, (2, 2, 3), stride=2) * 12
445
+ num_faces = num_ud_faces + num_ns_faces + num_we_faces
446
+
447
+ masked_ud_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_face) - 0.5) * 2
448
+ num_masked_ud_faces = F.avg_pool3d(masked_ud_faces, (3, 2, 2), stride=2) * 12
449
+ masked_ns_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_face) - 0.5) * 2
450
+ num_masked_ns_faces = F.avg_pool3d(masked_ns_faces, (2, 3, 2), stride=2) * 12
451
+ masked_we_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_face) - 0.5) * 2
452
+ num_masked_we_faces = F.avg_pool3d(masked_we_faces, (2, 2, 3), stride=2) * 12
453
+ num_masked_faces = num_masked_ud_faces + num_masked_ns_faces + num_masked_we_faces
454
+
455
+ # Count octants
456
+ K_octants = torch.tensor([[[0.125, 0.125], [0.125, 0.125]], [[0.125, 0.125], [0.125, 0.125]]], device=img.device).view(1, 1, 2, 2, 2)
457
+
458
+ octants = F.relu(F.conv3d(-(2.0 * img - 1.0), K_octants) - 0.75) * 4
459
+ num_octants = F.avg_pool3d(octants, (2, 2, 2), stride=2) * 8
460
+
461
+ masked_octants = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_octants) - 0.75) * 4
462
+ num_masked_octants = F.avg_pool3d(masked_octants, (2, 2, 2), stride=2) * 8
463
+
464
+ # Combine the numbers of vertices, edges, faces and octants to calculate the Euler characteristic
465
+ euler_characteristic = num_vertices - num_edges + num_faces - num_octants
466
+ masked_euler_characteristic = num_masked_vertices - num_masked_edges + num_masked_faces - num_masked_octants
467
+
468
+ # If the Euler characteristic is unchanged after switching a point from 1 to 0, the point is simple
469
+ euler_change = F.hardtanh(torch.abs(masked_euler_characteristic - euler_characteristic), min_val=0, max_val=1)
470
+ is_simple = 1 - euler_change
471
+ is_simple = (is_simple.detach() > 0.5).float() - is_simple.detach() + is_simple
472
+
473
+ return is_simple
474
+
475
+
476
+ def _prepare_output(self, img):
477
+ """
478
+ Function that removes the padding and extra dimensions added by the _prepare_input function.
479
+ """
480
+
481
+ img = img[:, :, 1:-1, 1:-1, 1:-1]
482
+
483
+ if self.expanded_dims:
484
+ img = torch.squeeze(img, dim=2)
485
+
486
+ return img
util/tools.py ADDED
@@ -0,0 +1,143 @@
1
+ '''
2
+ Author: Jintao Li
3
+ Date: 2022-05-30 16:42:14
4
+ LastEditors: Jintao Li
5
+ LastEditTime: 2022-07-11 23:05:53
6
+ 2022 by CIG.
7
+ '''
8
+
9
+ import os, shutil
10
+ import yaml, argparse
11
+ from sklearn.metrics import confusion_matrix
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ def accuracy(output, target):
17
+ '''
18
+ output: [N, num_classes, ...], torch.float
19
+ target: [N, ...], torch.int
20
+ '''
21
+ output = output.argmax(dim=1).flatten().detach().cpu().numpy()
22
+ target = target.flatten().detach().cpu().numpy()
23
+ return pixel_acc(output, target), _miou(output, target)
24
+
25
+
26
+ def pixel_acc(output, target):
27
+ r"""
28
+ Compute the pixel accuracy (PA)
29
+ $$ PA = \frac{\sum_{i=0}^k p_{ii}}
30
+ {\sum_{i=0}^k \sum_{j=0}^k p_{ij}} $$ and
31
+ $n_class = k+1$
32
+ Parameters:
33
+ -----------
34
+ shape: [N, ], (use flatten() function)
35
+ return:
36
+ ----------
37
+ - PA
38
+ """
39
+ assert output.shape == target.shape, "shapes must be same"
40
+ cm = confusion_matrix(target, output)
41
+ return np.diag(cm).sum() / cm.sum()
42
+
43
+
44
+ def _miou(output, target):
45
+ r"""
46
+ Compute the mean intersection over union (MIoU)
47
+ $$ MIoU = \frac{1}{k+1} \sum_{i=0}^k \frac{p_{ii}}
48
+ {\sum_{j=0}^k p_{ij} + \sum_{j=0}^k p_{ji} - p_{ii}} $$
49
+ Parameters:
50
+ output, target: [N, ]
51
+ return:
52
+ MIoU
53
+ """
54
+ assert output.shape == target.shape, "shapes must be same"
55
+ cm = confusion_matrix(target, output)
56
+ intersection = np.diag(cm)
57
+ union = np.sum(cm, 1) + np.sum(cm, 0) - np.diag(cm)
58
+ iou = intersection / union
59
+ miou = np.nanmean(iou)
60
+
61
+ return miou
62
+
63
+
64
+ def yaml_config_hook(config_file: str) -> argparse.Namespace:
65
+ """
66
+ Load the parameter configuration from a YAML file and build an argparse-style namespace of arguments
67
+ """
68
+ with open(config_file) as f:
69
+ cfg = yaml.safe_load(f)
70
+ for d in cfg.get("defaults", []):
71
+ config_dir, cf = d.popitem()
72
+ cf = os.path.join(os.path.dirname(config_file), config_dir,
73
+ cf + ".yaml")
74
+ with open(cf) as f:
75
+ l = yaml.safe_load(f)
76
+ cfg.update(l)
77
+
78
+ if "defaults" in cfg.keys():
79
+ del cfg["defaults"]
80
+
81
+ parser = argparse.ArgumentParser()
82
+ for k, v in cfg.items():
83
+ parser.add_argument(f"--{k}", default=v, type=type(v))
84
+ args = parser.parse_args()
85
+
86
+ return args
87
+
88
+
89
+ def backup_code(work_dir, back_dir, exceptions=[], include=[]):
90
+ r"""
91
+ Back up the code of the current run to a target directory, excluding certain files and directories.
92
+
93
+ Args:
94
+ work_dir: working directory, i.e. the code to be backed up
95
+ back_dir: target directory where the backed-up code is placed
96
+ exceptions (list): excluded directories and file-suffix patterns; the defaults are
97
+ ["__pycache__", ".pyc", ".dat", "backup", ".vscode"]
98
+ include (list): files that must be backed up even if they match an exclusion pattern
99
+ """
100
+ _exp = [
101
+ "*__pycache__*", "*.pyc", "*.dat", "backup", ".vscode", "*.log",
102
+ "*log*"
103
+ ]
104
+ exceptions = exceptions + _exp
105
+
106
+ # if not os.path.exists(back_dir):
107
+ os.makedirs(back_dir, exist_ok=True)
108
+
109
+ shutil.copytree(work_dir,
110
+ back_dir + 'code/',
111
+ ignore=shutil.ignore_patterns(*exceptions),
112
+ dirs_exist_ok=True)
113
+
114
+ for f in include:
115
+ shutil.copyfile(os.path.join(work_dir, f),
116
+ os.path.join(back_dir + 'code', f))
117
+
118
+
119
+ def list_files(path, full=False):
120
+ r"""
121
+ Recursively list all files under a directory, including files in subdirectories
122
+ """
123
+ out = []
124
+ for f in os.listdir(path):
125
+ fname = os.path.join(path, f)
126
+ if os.path.isdir(fname):
127
+ fname = list_files(fname)
128
+ out += [os.path.join(f, i) for i in fname]
129
+ else:
130
+ out.append(f)
131
+ if full:
132
+ out = [os.path.join(path, i) for i in out]
133
+ return out
134
+
135
+
136
+ if __name__ == "__main__":
137
+ output = torch.randn(4, 2, 6, 6)
138
+ target = torch.randn(4, 2, 6, 6)
139
+ # output = output.cuda()
140
+ # target = target.cuda()
141
+ target = target.argmax(1)
142
+
143
+ accuracy(output, target)
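
A short usage sketch (illustration only, not part of the commit) for the metrics and config helpers above; tensor shapes and the YAML path are placeholders.

    import torch
    from util.tools import accuracy, yaml_config_hook

    logits = torch.randn(4, 6, 128, 128)            # [N, num_classes, H, W] network output
    labels = torch.randint(0, 6, (4, 128, 128))     # [N, H, W] integer ground truth
    pa, miou = accuracy(logits, labels)
    print(f"pixel accuracy: {pa:.3f}, mIoU: {miou:.3f}")

    # args = yaml_config_hook("configs/facies.yaml")  # hypothetical config file path
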
util/variable_pos_embed.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Variable size position embedding utils for handling different image dimensions
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
+
14
+
15
+ def get_2d_sincos_pos_embed_variable(embed_dim, grid_h, grid_w, cls_token=False):
16
+ """
17
+ Create 2D sine-cosine position embeddings for variable grid sizes
18
+
19
+ Args:
20
+ embed_dim: embedding dimension
21
+ grid_h: height of the grid (number of patches in height)
22
+ grid_w: width of the grid (number of patches in width)
23
+ cls_token: whether to include class token
24
+
25
+ Returns:
26
+ pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
27
+ """
28
+ grid_h_coords = np.arange(grid_h, dtype=np.float32)
29
+ grid_w_coords = np.arange(grid_w, dtype=np.float32)
30
+ grid = np.meshgrid(grid_w_coords, grid_h_coords) # here w goes first
31
+ grid = np.stack(grid, axis=0)
32
+
33
+ grid = grid.reshape([2, 1, grid_h, grid_w])
34
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35
+ if cls_token:
36
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
37
+ return pos_embed
38
+
39
+
40
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41
+ assert embed_dim % 2 == 0
42
+
43
+ # use half of dimensions to encode grid_h
44
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46
+
47
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48
+ return emb
49
+
50
+
51
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52
+ """
53
+ embed_dim: output dimension for each position
54
+ pos: a list of positions to be encoded: size (M,)
55
+ out: (M, D)
56
+ """
57
+ assert embed_dim % 2 == 0
58
+ omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24+
59
+ omega /= embed_dim / 2.
60
+ omega = 1. / 10000**omega # (D/2,)
61
+
62
+ pos = pos.reshape(-1) # (M,)
63
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64
+
65
+ emb_sin = np.sin(out) # (M, D/2)
66
+ emb_cos = np.cos(out) # (M, D/2)
67
+
68
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69
+ return emb
70
+
71
+
72
+ def interpolate_pos_embed_variable(original_pos_embed, target_h, target_w, cls_token=True):
73
+ """
74
+ Interpolate position embeddings for arbitrary target sizes
75
+
76
+ Args:
77
+ original_pos_embed: original positional embeddings [1, N, D]
78
+ target_h: target height in patches
79
+ target_w: target width in patches
80
+ cls_token: whether the first token is a class token
81
+
82
+ Returns:
83
+ interpolated_pos_embed: [1, target_h*target_w + cls_token, D]
84
+ """
85
+ embed_dim = original_pos_embed.shape[-1]
86
+
87
+ if cls_token:
88
+ class_pos_embed = original_pos_embed[:, 0:1] # [1, 1, D]
89
+ patch_pos_embed = original_pos_embed[:, 1:] # [1, N-1, D]
90
+ orig_num_patches = patch_pos_embed.shape[1]
91
+ else:
92
+ class_pos_embed = None
93
+ patch_pos_embed = original_pos_embed
94
+ orig_num_patches = patch_pos_embed.shape[1]
95
+
96
+ # Determine original grid size (assume square for original)
97
+ orig_h = orig_w = int(np.sqrt(orig_num_patches))
98
+
99
+ if orig_h * orig_w != orig_num_patches:
100
+ raise ValueError(f"Original number of patches {orig_num_patches} is not a perfect square")
101
+
102
+ # Reshape to spatial dimensions
103
+ patch_pos_embed = patch_pos_embed.reshape(1, orig_h, orig_w, embed_dim)
104
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) # [1, D, orig_h, orig_w]
105
+
106
+ # Interpolate to target size
107
+ patch_pos_embed = F.interpolate(
108
+ patch_pos_embed,
109
+ size=(target_h, target_w),
110
+ mode='bicubic',
111
+ align_corners=False
112
+ )
113
+
114
+ # Reshape back to token sequence
115
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1) # [1, target_h, target_w, D]
116
+ patch_pos_embed = patch_pos_embed.flatten(1, 2) # [1, target_h*target_w, D]
117
+
118
+ if cls_token:
119
+ new_pos_embed = torch.cat([class_pos_embed, patch_pos_embed], dim=1)
120
+ else:
121
+ new_pos_embed = patch_pos_embed
122
+
123
+ return new_pos_embed
124
+
125
+
126
+ def create_variable_pos_embed(embed_dim, height_patches, width_patches, cls_token=True):
127
+ """
128
+ Create positional embeddings for specific patch grid dimensions
129
+
130
+ Args:
131
+ embed_dim: embedding dimension
132
+ height_patches: number of patches in height
133
+ width_patches: number of patches in width
134
+ cls_token: whether to include class token
135
+
136
+ Returns:
137
+ pos_embed: positional embeddings tensor
138
+ """
139
+ pos_embed_np = get_2d_sincos_pos_embed_variable(
140
+ embed_dim, height_patches, width_patches, cls_token=cls_token
141
+ )
142
+ pos_embed = torch.from_numpy(pos_embed_np).float().unsqueeze(0)
143
+ return pos_embed