Commit a3f0d6c · Anirudh Bhalekar · Parent(s): 6e87dc7
added models and util folder
Files changed:
- models_Facies.py +397 -0
- models_Fault.py +327 -0
- util/__pycache__/datasets.cpython-311.pyc +0 -0
- util/__pycache__/datasets.cpython-312.pyc +0 -0
- util/__pycache__/datasets.cpython-36.pyc +0 -0
- util/__pycache__/datasets.cpython-37.pyc +0 -0
- util/__pycache__/lars.cpython-36.pyc +0 -0
- util/__pycache__/lr_decay.cpython-311.pyc +0 -0
- util/__pycache__/lr_decay.cpython-312.pyc +0 -0
- util/__pycache__/lr_decay.cpython-36.pyc +0 -0
- util/__pycache__/lr_decay.cpython-37.pyc +0 -0
- util/__pycache__/lr_sched.cpython-311.pyc +0 -0
- util/__pycache__/lr_sched.cpython-312.pyc +0 -0
- util/__pycache__/lr_sched.cpython-36.pyc +0 -0
- util/__pycache__/lr_sched.cpython-37.pyc +0 -0
- util/__pycache__/metrics.cpython-36.pyc +0 -0
- util/__pycache__/misc.cpython-311.pyc +0 -0
- util/__pycache__/misc.cpython-312.pyc +0 -0
- util/__pycache__/misc.cpython-36.pyc +0 -0
- util/__pycache__/misc.cpython-37.pyc +0 -0
- util/__pycache__/msssim.cpython-311.pyc +0 -0
- util/__pycache__/msssim.cpython-312.pyc +0 -0
- util/__pycache__/msssim.cpython-36.pyc +0 -0
- util/__pycache__/msssim.cpython-37.pyc +0 -0
- util/__pycache__/pos_embed.cpython-311.pyc +0 -0
- util/__pycache__/pos_embed.cpython-312.pyc +0 -0
- util/__pycache__/pos_embed.cpython-36.pyc +0 -0
- util/__pycache__/pos_embed.cpython-37.pyc +0 -0
- util/__pycache__/size_aware_batching.cpython-312.pyc +0 -0
- util/__pycache__/skeletonize.cpython-312.pyc +0 -0
- util/__pycache__/tools.cpython-311.pyc +0 -0
- util/__pycache__/tools.cpython-312.pyc +0 -0
- util/__pycache__/tools.cpython-36.pyc +0 -0
- util/__pycache__/tools.cpython-37.pyc +0 -0
- util/__pycache__/variable_pos_embed.cpython-312.pyc +0 -0
- util/crop.py +42 -0
- util/datasets.py +599 -0
- util/lars.py +47 -0
- util/lr_decay.py +76 -0
- util/lr_sched.py +21 -0
- util/metrics.py +90 -0
- util/misc.py +340 -0
- util/msssim.py +146 -0
- util/pos_embed.py +104 -0
- util/pos_embedtest.py +127 -0
- util/post_processing.py +305 -0
- util/size_aware_batching.py +251 -0
- util/skeletonize.py +486 -0
- util/tools.py +143 -0
- util/variable_pos_embed.py +143 -0
models_Facies.py
ADDED
@@ -0,0 +1,397 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer
import numpy as np
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable


class FlexiblePatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding that handles variable input sizes """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_patches = (img_size // patch_size) ** 2  # default number of patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        # Calculate number of patches dynamically
        self.num_patches = (H // self.patch_size) * (W // self.patch_size)
        x = self.proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for variable image sizes and adaptive positional embeddings
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        self.decoder = VIT_MLAHead(mla_channels=self.embed_dim, num_classes=self.num_classes)

        self.segmentation_head = SegmentationHead(
            in_channels=16,
            out_channels=self.num_classes,
            kernel_size=3,
        )
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def interpolate_pos_encoding(self, x, h, w):
        """
        Interpolate positional embeddings for arbitrary input sizes
        """
        npatch = x.shape[1] - 1  # subtract 1 for cls token
        N = self.pos_embed.shape[1] - 1  # original number of patches

        if npatch == N and h == w:
            return self.pos_embed

        # Use the new variable position embedding utility
        return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)

    def forward_features(self, x):
        B, C, H, W = x.shape

        # Handle padding for non-16-divisible images
        patch_size = self.patch_embed.patch_size
        pad_h = (patch_size - H % patch_size) % patch_size
        pad_w = (patch_size - W % patch_size) % patch_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
            H_padded, W_padded = H + pad_h, W + pad_w
        else:
            H_padded, W_padded = H, W

        # Extract patches
        x = self.patch_embed(x)
        _H, _W = H_padded // patch_size, W_padded // patch_size

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add interpolated positional embeddings
        pos_embed = self.interpolate_pos_encoding(x, _H, _W)
        x = x + pos_embed
        x = self.pos_drop(x)

        featureskip = []
        featureskipnum = 1
        for blk in self.blocks:
            x = blk(x)
            if featureskipnum % (len(self.blocks) // 4) == 0:
                featureskip.append(x[:, 1:, :])  # exclude cls token
            featureskipnum += 1

        # Pass original dimensions for proper reconstruction
        x = self.decoder(featureskip[0], featureskip[1], featureskip[2], featureskip[3],
                         h=_H, w=_W, target_h=H, target_w=W)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        # print(x.shape, skip.shape)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.up(x)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self):
        super().__init__()
        # self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            1024,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )

        decoder_channels = (256, 128, 64, 16)

        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        # if self.config.n_skip != 0:
        #     skip_channels = self.config.skip_channels
        #     for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
        #         skip_channels[3 - i] = 0
        # else:
        #     skip_channels = [0, 0, 0, 0]
        skip_channels = [512, 256, 128, 64]
        self.conv_feature1 = Conv2dReLU(1024, skip_channels[0], kernel_size=3, padding=1, use_batchnorm=True)
        self.conv_feature2 = Conv2dReLU(1024, skip_channels[1], kernel_size=3, padding=1, use_batchnorm=True)
        self.up2 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv_feature3 = Conv2dReLU(1024, skip_channels[2], kernel_size=3, padding=1, use_batchnorm=True)
        self.up3 = nn.UpsamplingBilinear2d(scale_factor=4)
        self.conv_feature4 = Conv2dReLU(1024, skip_channels[3], kernel_size=3, padding=1, use_batchnorm=True)
        self.up4 = nn.UpsamplingBilinear2d(scale_factor=8)

        # skip_channels = [128, 64, 32, 8]
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def TransShape(self, x, head_channels=512, up=0):
        B, n_patch, hidden = x.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)

        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = x.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        if up == 0:
            x = self.conv_feature1(x)
        elif up == 1:
            x = self.conv_feature2(x)
            x = self.up2(x)
        elif up == 2:
            x = self.conv_feature3(x)
            x = self.up3(x)
        elif up == 3:
            x = self.conv_feature4(x)
            x = self.up4(x)
        return x

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        skip_channels = [512, 256, 128, 64]
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = self.TransShape(features[i], head_channels=skip_channels[i], up=i)
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class MLAHead(nn.Module):
    def __init__(self, mla_channels=256, mlahead_channels=128, norm_cfg=None):
        super(MLAHead, self).__init__()
        self.head2 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(
                                       mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head3 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(
                                       mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head4 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(
                                       mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())
        self.head5 = nn.Sequential(nn.Conv2d(mla_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU(),
                                   nn.Conv2d(
                                       mlahead_channels, mlahead_channels, 3, padding=1, bias=False),
                                   nn.BatchNorm2d(mlahead_channels), nn.ReLU())

    def forward(self, mla_p2, mla_p3, mla_p4, mla_p5):
        head2 = F.interpolate(self.head2(
            mla_p2), (4 * mla_p2.shape[-2], 4 * mla_p2.shape[-1]), mode='bilinear', align_corners=True)
        head3 = F.interpolate(self.head3(
            mla_p3), (4 * mla_p3.shape[-2], 4 * mla_p3.shape[-1]), mode='bilinear', align_corners=True)
        head4 = F.interpolate(self.head4(
            mla_p4), (4 * mla_p4.shape[-2], 4 * mla_p4.shape[-1]), mode='bilinear', align_corners=True)
        head5 = F.interpolate(self.head5(
            mla_p5), (4 * mla_p5.shape[-2], 4 * mla_p5.shape[-1]), mode='bilinear', align_corners=True)
        return torch.cat([head2, head3, head4, head5], dim=1)


class VIT_MLAHead(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(self, img_size=768, mla_channels=256, mlahead_channels=128, num_classes=6,
                 norm_layer=nn.BatchNorm2d, norm_cfg=None, **kwargs):
        super(VIT_MLAHead, self).__init__(**kwargs)
        self.img_size = img_size
        self.norm_cfg = norm_cfg
        self.mla_channels = mla_channels
        self.BatchNorm = norm_layer
        self.mlahead_channels = mlahead_channels
        self.num_classes = num_classes
        self.mlahead = MLAHead(mla_channels=self.mla_channels,
                               mlahead_channels=self.mlahead_channels, norm_cfg=self.norm_cfg)
        self.cls = nn.Conv2d(4 * self.mlahead_channels,
                             self.num_classes, 3, padding=1)

    def forward(self, x1, x2, x3, x4, h=14, w=14, target_h=None, target_w=None):
        B, n_patch, hidden = x1.size()
        if h == w:
            h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))

        # Reshape all feature maps
        x1 = x1.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x2 = x2.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x3 = x3.permute(0, 2, 1).contiguous().view(B, hidden, h, w)
        x4 = x4.permute(0, 2, 1).contiguous().view(B, hidden, h, w)

        # Apply MLA head
        x = self.mlahead(x1, x2, x3, x4)
        x = self.cls(x)

        # Calculate target size - if original image wasn't patch-size divisible
        patch_size = 16  # assuming patch size of 16
        if target_h is not None and target_w is not None:
            target_size = (target_h, target_w)
        else:
            target_size = (h * patch_size, w * patch_size)

        # Interpolate to target size
        x = F.interpolate(x, size=target_size, mode='bilinear', align_corners=True)
        return x


def mae_vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1024
    )
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=14,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1280
    )
    return model
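A minimal usage sketch for the file above (not part of the commit): it builds the facies segmentation ViT through one of the factory functions and runs a forward pass on a deliberately non-square, non-16-divisible input so the reflect padding, positional-embedding interpolation, and decoder resizing paths are all exercised. The class count, channel count, and image size are assumptions, and a compatible timm install plus the util helpers are required.

import torch
import models_Facies

# 6 facies classes, single-channel seismic input; img_size only sets the default pos-embed grid
model = models_Facies.vit_large_patch16(num_classes=6, in_chans=1, img_size=768)
model.eval()

with torch.no_grad():
    x = torch.randn(1, 1, 300, 500)   # arbitrary H x W; forward_features pads to multiples of 16
    logits = model(x)                 # VIT_MLAHead interpolates back to the original size
    print(logits.shape)               # expected: torch.Size([1, 6, 300, 500])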
models_Fault.py
ADDED
@@ -0,0 +1,327 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------


from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import timm.models.vision_transformer
import numpy as np
from util.msssim import MSSSIM
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable


class FlexiblePatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding that handles variable input sizes """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.num_patches = (img_size // patch_size) ** 2  # default number of patches
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        # Calculate number of patches dynamically
        self.num_patches = (H // self.patch_size) * (W // self.patch_size)
        x = self.proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """ Vision Transformer with support for global average pooling
    """
    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        self.decoder = DecoderCup(in_channels=[self.embed_dim, 256, 128, 64])

        self.segmentation_head = SegmentationHead(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=1
        )
        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def interpolate_pos_encoding(self, x, h, w):
        """
        Interpolate positional embeddings for arbitrary input sizes
        """
        npatch = x.shape[1] - 1  # subtract 1 for cls token
        N = self.pos_embed.shape[1] - 1  # original number of patches

        if npatch == N and h == w:
            return self.pos_embed

        # Use the new variable position embedding utility
        return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)


    def generate_mask(self, input_tensor, ratio):
        mask = torch.zeros_like(input_tensor)
        indices = torch.randperm(mask.size(3) // 16)[:int(mask.size(3) // 16 * ratio)]
        sorted_indices = torch.sort(indices)[0]
        for i in range(0, len(sorted_indices)):
            mask[:, :, :, sorted_indices[i] * 16:(sorted_indices[i] + 1) * 16] = 1
        return mask

    def forward_features(self, x):
        B, C, H, W = x.shape

        # Handle padding for non-16-divisible images
        patch_size = self.patch_embed.patch_size
        pad_h = (patch_size - H % patch_size) % patch_size
        pad_w = (patch_size - W % patch_size) % patch_size

        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
            H_padded, W_padded = H + pad_h, W + pad_w
        else:
            H_padded, W_padded = H, W

        img = x
        x = self.patch_embed(x)

        _H, _W = H_padded // patch_size, W_padded // patch_size

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add interpolated positional embeddings
        pos_embed = self.interpolate_pos_encoding(x, _H, _W)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        x = self.decoder(x[:, 1:, :], img)
        x = self.segmentation_head(x)
        return x

    def forward(self, x):

        x = self.forward_features(x)

        return x

    def inference(self, x):
        x = self.forward_features(x)
        x = F.softmax(x, dim=1)

        return x


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)

        bn = nn.BatchNorm2d(out_channels)

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=1, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, in_channels=[1024, 256, 128, 64]):
        super().__init__()
        head_channels = 512
        self.conv_more = Conv2dReLU(
            1,
            32,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        skip_channels = [0, 0, 0, 32]
        out_channels = [256, 128, 64, 64]
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, img, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        skip_channels = [None, None, None, self.conv_more(img)]
        for i, decoder_block in enumerate(self.blocks):
            x = decoder_block(x, skip=skip_channels[i])
        return x


def forward_loss(imgs, pred):
    """
    imgs: [N, 3, H, W]
    pred: [N, L, p*p*3]
    mask: [N, L], 0 is keep, 1 is remove,
    """
    loss1f = torch.nn.MSELoss()
    loss1 = loss1f(imgs, pred)
    loss2f = MSSSIM()
    loss2 = loss2f(imgs, pred)
    a = 0.5
    loss = (1 - a) * loss1 + a * loss2
    return loss


def weighted_cross_entropy(pred, target):
    """
    Compute the weighted cross entropy loss.
    NEED VERIFICATION
    """

    # Function to compute weighted cross entropy loss
    # target: [batch, channel, s, s]
    # pred: [batch, channel, s, s]

    # print('pred shape ', pred.shape)
    # print('target shape ', target.shape)
    # print('--------------')
    # print('sums of pred', torch.sum(pred))
    # print('sums of target', torch.sum(target))
    # beta is the fraction of non-fault pixels in the target (i.e. the zeroes in the target)
    beta = torch.mean(target)  # fraction of fault pixels
    beta = 1 - beta  # fraction of non-fault pixels
    beta = torch.clamp(beta, min=0.01, max=0.99)  # avoid division by zero

    # print('beta', beta)

    # Compute the weighted cross entropy loss
    loss = -(beta * target * torch.log(pred + 1e-8) + (1 - beta) * (1 - target) * torch.log(1 - pred + 1e-8))
    return torch.mean(loss)


def mae_vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768
    )
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1024
    )
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=14,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1280
    )
    return model
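A hypothetical training-step sketch for the fault model and the losses above (not part of the commit): it assumes binary fault masks in [0, 1], a single-channel input, and a model built with num_classes=1 so its output can be passed through a sigmoid before the pixel-wise weighted cross-entropy.

import torch
import models_Fault

model = models_Fault.vit_base_patch16(num_classes=1, in_chans=1, img_size=224)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

seis = torch.randn(2, 1, 224, 224)                       # seismic patches
fault = torch.randint(0, 2, (2, 1, 224, 224)).float()    # binary fault labels

pred = torch.sigmoid(model(seis))                        # probabilities in (0, 1)
loss = models_Fault.weighted_cross_entropy(pred, fault)  # beta re-weights the sparse fault class
loss.backward()
optimizer.step()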
util/__pycache__/datasets.cpython-311.pyc · ADDED · Binary file (23.6 kB)
util/__pycache__/datasets.cpython-312.pyc · ADDED · Binary file (32.1 kB)
util/__pycache__/datasets.cpython-36.pyc · ADDED · Binary file (19.2 kB)
util/__pycache__/datasets.cpython-37.pyc · ADDED · Binary file (19.6 kB)
util/__pycache__/lars.cpython-36.pyc · ADDED · Binary file (1.34 kB)
util/__pycache__/lr_decay.cpython-311.pyc · ADDED · Binary file (2.66 kB)
util/__pycache__/lr_decay.cpython-312.pyc · ADDED · Binary file (2.39 kB)
util/__pycache__/lr_decay.cpython-36.pyc · ADDED · Binary file (1.6 kB)
util/__pycache__/lr_decay.cpython-37.pyc · ADDED · Binary file (1.59 kB)
util/__pycache__/lr_sched.cpython-311.pyc · ADDED · Binary file (1.02 kB)
util/__pycache__/lr_sched.cpython-312.pyc · ADDED · Binary file (1.12 kB)
util/__pycache__/lr_sched.cpython-36.pyc · ADDED · Binary file (595 Bytes)
util/__pycache__/lr_sched.cpython-37.pyc · ADDED · Binary file (599 Bytes)
util/__pycache__/metrics.cpython-36.pyc · ADDED · Binary file (3.83 kB)
util/__pycache__/misc.cpython-311.pyc · ADDED · Binary file (21.4 kB)
util/__pycache__/misc.cpython-312.pyc · ADDED · Binary file (19.4 kB)
util/__pycache__/misc.cpython-36.pyc · ADDED · Binary file (10.8 kB)
util/__pycache__/misc.cpython-37.pyc · ADDED · Binary file (10.8 kB)
util/__pycache__/msssim.cpython-311.pyc · ADDED · Binary file (8.9 kB)
util/__pycache__/msssim.cpython-312.pyc · ADDED · Binary file (7.84 kB)
util/__pycache__/msssim.cpython-36.pyc · ADDED · Binary file (4.51 kB)
util/__pycache__/msssim.cpython-37.pyc · ADDED · Binary file (4.49 kB)
util/__pycache__/pos_embed.cpython-311.pyc · ADDED · Binary file (4.35 kB)
util/__pycache__/pos_embed.cpython-312.pyc · ADDED · Binary file (4.14 kB)
util/__pycache__/pos_embed.cpython-36.pyc · ADDED · Binary file (2.43 kB)
util/__pycache__/pos_embed.cpython-37.pyc · ADDED · Binary file (2.42 kB)
util/__pycache__/size_aware_batching.cpython-312.pyc · ADDED · Binary file (10.6 kB)
util/__pycache__/skeletonize.cpython-312.pyc · ADDED · Binary file (35.5 kB)
util/__pycache__/tools.cpython-311.pyc · ADDED · Binary file (7.76 kB)
util/__pycache__/tools.cpython-312.pyc · ADDED · Binary file (6.81 kB)
util/__pycache__/tools.cpython-36.pyc · ADDED · Binary file (4.25 kB)
util/__pycache__/tools.cpython-37.pyc · ADDED · Binary file (4.26 kB)
util/__pycache__/variable_pos_embed.cpython-312.pyc · ADDED · Binary file (5.42 kB)
util/crop.py
ADDED
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch

from torchvision import transforms
from torchvision.transforms import functional as F


class RandomResizedCrop(transforms.RandomResizedCrop):
    """
    RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
    This may lead to results different with torchvision's version.
    Following BYOL's TF code:
    https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
    """
    @staticmethod
    def get_params(img, scale, ratio):
        width, height = F._get_image_size(img)
        area = height * width

        target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
        log_ratio = torch.log(torch.tensor(ratio))
        aspect_ratio = torch.exp(
            torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
        ).item()

        w = int(round(math.sqrt(target_area * aspect_ratio)))
        h = int(round(math.sqrt(target_area / aspect_ratio)))

        w = min(w, width)
        h = min(h, height)

        i = torch.randint(0, height - h + 1, size=(1,)).item()
        j = torch.randint(0, width - w + 1, size=(1,)).item()

        return i, j, h, w
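A short usage sketch for util/crop.py (not part of the commit): the class drops into a standard torchvision pipeline in place of transforms.RandomResizedCrop. Note that get_params calls the private torchvision helper F._get_image_size, which older torchvision releases provide; newer releases expose get_image_size instead, so the version pin matters.

from PIL import Image
from torchvision import transforms
from util.crop import RandomResizedCrop

train_transform = transforms.Compose([
    RandomResizedCrop(224, scale=(0.2, 1.0), interpolation=3),  # 3 = bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

img = Image.new('L', (640, 480))   # placeholder grayscale image
out = train_transform(img)         # tensor of shape [1, 224, 224]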
util/datasets.py
ADDED
@@ -0,0 +1,599 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import os
import PIL

import os, random, glob
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from os.path import isfile, join
import segyio
from itertools import permutations

random.seed(42)

from torchvision import datasets, transforms

from timm.data import create_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    root = os.path.join(args.data_path, 'train' if is_train else 'val')
    dataset = datasets.ImageFolder(root, transform=transform)

    print(dataset)

    return dataset


def build_transform(is_train, args):
    mean = IMAGENET_DEFAULT_MEAN
    std = IMAGENET_DEFAULT_STD
    # train transform
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        return transform

    # eval transform
    t = []
    if args.input_size <= 224:
        crop_pct = 224 / 256
    else:
        crop_pct = 1.0
    size = int(args.input_size / crop_pct)
    t.append(
        transforms.Resize(size, interpolation=PIL.Image.BICUBIC),  # to maintain same ratio w.r.t. 224 images
    )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean, std))
    return transforms.Compose(t)


## pretrain
class SeismicSet(data.Dataset):

    def __init__(self, path, input_size) -> None:
        super().__init__()
        # self.file_list = os.listdir(path)
        # self.file_list = [os.path.join(path, f) for f in self.file_list]
        self.get_file_list(path)
        self.input_size = input_size
        print(len(self.file_list))

    def __len__(self) -> int:
        return len(self.file_list)
        # return 100000

    def __getitem__(self, index):
        d = np.fromfile(self.file_list[index], dtype=np.float32)
        d = d.reshape(1, self.input_size, self.input_size)
        d = (d - d.mean()) / (d.std() + 1e-6)

        # return to_transforms(d, self.input_size)
        return d, torch.tensor([1])

    def get_file_list(self, path):
        dirs = [os.path.join(path, f) for f in os.listdir(path)]
        self.file_list = dirs

        # for ds in dirs:
        #     if os.path.isdir(ds):
        #         self.file_list += [os.path.join(ds, f) for f in os.listdir(ds)]

        return random.shuffle(self.file_list)


def to_transforms(d, input_size):
    t = transforms.Compose([
        transforms.RandomResizedCrop(input_size,
                                     scale=(0.2, 1.0),
                                     interpolation=3),  # 3 is bicubic
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    return t(d)


### finetune
class FacesSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[768, 768],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape

        # self.data_list = sorted(glob.glob(folder + 'seismic/*.dat'))
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(117)]

        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list[:100]
        elif not is_train:
            self.data_list = self.data_list[100:]
        self.label_list = [
            f.replace('/seismic/', '/label/') for f in self.data_list
        ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape) - 1
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class SaltSet(data.Dataset):

    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(4000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list[:3500]
        elif not is_train:
            self.data_list = self.data_list[3500:]
        self.label_list = [
            f.replace('/seismic/', '/label/') for f in self.data_list
        ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32).reshape(self.shape)
        l = l.astype(int)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class InterpolationSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + str(f) + '.dat' for f in range(6000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
        elif not is_train:
            self.data_list = [folder + 'U' + str(f) + '.dat' for f in range(2000, 4000)]
        self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(d)

    def __len__(self):
        return len(self.data_list)
        # return 10000


class DenoiseSet(data.Dataset):
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2000)]
        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
            self.label_list = [f.replace('/seismic/', '/label/') for f in self.data_list]
        elif not is_train:
            self.data_list = [folder + 'field/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = self.data_list

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d.reshape([1] + self.shape)
        # d = (d - d.mean())/d.std()
        l = np.fromfile(self.label_list[index], np.float32)
        l = l.reshape([1] + self.shape)
        # l = (l - d.mean())/l.std()
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ReflectSet(data.Dataset):
    # folder/train/data/**.dat, folder/train/label/**.dat
    # folder/test/data/**.dat, folder/test/label/**.dat
    def __init__(self,
                 folder,
                 shape=[224, 224],
                 is_train=True) -> None:
        super().__init__()
        self.shape = shape
        self.data_list = [folder + 'seismic/' + str(f) + '.dat' for f in range(2200)]

        n = len(self.data_list)
        if is_train:
            self.data_list = self.data_list
            self.label_list = [
                f.replace('/seismic/', '/label/') for f in self.data_list
            ]
        elif not is_train:
            self.data_list = [folder + 'SEAMseismic/' + str(f) + '.dat' for f in range(4000)]
            self.label_list = [
                f.replace('/SEAMseismic/', '/SEAMreflect/') for f in self.data_list
            ]

    def __getitem__(self, index):
        d = np.fromfile(self.data_list[index], np.float32)
        d = d - d.mean()
        d = d / (d.std() + 1e-6)
        d = d.reshape([1] + self.shape)
        l = np.fromfile(self.label_list[index], np.float32)
        l = l - l.mean()
        l = l / (l.std() + 1e-6)
        l = l.reshape([1] + self.shape)
        return torch.tensor(d), torch.tensor(l)

    def __len__(self):
        return len(self.data_list)


class ThebeSet(data.Dataset):
    def __init__(self, folder, shape=[224, 224], mode='train') -> None:
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.num_files = len(os.listdir(join(folder, 'fault')))
        self.shape = shape
        self.fault_list = [folder + '/fault/{i}.npy'.format(i=i) for i in range(1, self.num_files + 1)]
        self.seis_list = [folder + '/seis/{i}.npy'.format(i=i) for i in range(1, self.num_files + 1)]

        self.train_size = int(0.75 * self.num_files)
        self.val_size = int(0.15 * self.num_files)
        self.test_size = self.num_files - self.train_size - self.val_size

        self.train_index = self.train_size
        self.val_index = self.train_index + self.val_size

        if mode == 'train':
            self.fault_list = self.fault_list[:self.train_index]
            self.seis_list = self.seis_list[:self.train_index]
        elif mode == 'val':
            self.fault_list = self.fault_list[self.train_index:self.val_index]
            self.seis_list = self.seis_list[self.train_index:self.val_index]
        elif mode == 'test':
            self.fault_list = self.fault_list[self.val_index:]
            self.seis_list = self.seis_list[self.val_index:]
        else:
            raise ValueError("Mode must be 'train', 'val', or 'test'.")

    def __len__(self):
        return len(self.fault_list)

    def retrieve_patch(self, fault, seis):
        # image will (probably) be of size [3174, 1537]
        # return a patch of size [224, 224]

        patch_height = self.shape[0]
        patch_width = self.shape[1]

        h, w = fault.shape
        if h < patch_height or w < patch_width:
            raise ValueError(f"Image dimensions must be at least {patch_height}x{patch_width}.")

        top = random.randint(0, h - patch_height)
        left = random.randint(0, w - patch_width)

        return fault[top:top + patch_height, left:left + patch_width], seis[top:top + patch_height, left:left + patch_width]

    def random_transform(self, fault, seis):
        # Apply the same random transformations to the fault and seismic data
        # Mirror the patch horizontally
        if random.random() > 0.5:
            fault = np.fliplr(fault)
            seis = np.fliplr(seis)

        # Mirror the patch vertically
        if random.random() > 0.5:
            fault = np.flipud(fault)
            seis = np.flipud(seis)

        return fault, seis

    def __getitem__(self, index):
        # need to see if we do normalization here (i.e. what data pre-treatment we do)
        fault = np.load(self.fault_list[index])
        seis = np.load(self.seis_list[index])

        fault, seis = self.retrieve_patch(fault, seis)
        fault, seis = self.random_transform(fault, seis)

        seis = (seis - seis.mean()) / (seis.std() + 1e-6)

        fault = torch.tensor(fault.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0)

        return seis, fault


class FSegSet(data.Dataset):
    def __init__(self, folder, shape=[128, 128], mode='train') -> None:
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")

        self.shape = shape
        self.mode = mode

        if mode == 'train':
            self.fault_path = join(self.folder, 'train/fault')
            self.seis_path = join(self.folder, 'train/seis')
        elif mode == 'val':
            self.fault_path = join(self.folder, 'val/fault')
            self.seis_path = join(self.folder, 'val/seis')
        else:
            raise ValueError("Mode must be 'train' or 'val'.")

        self.fault_list = [join(self.fault_path, f) for f in os.listdir(self.fault_path) if f.endswith('.npy')]
        self.seis_list = [join(self.seis_path, f) for f in os.listdir(self.seis_path) if f.endswith('.npy')]

    def __len__(self):
        return len(self.fault_list)

    def __getitem__(self, index):

        fault_img, seis_img = np.load(self.fault_list[index]), np.load(self.seis_list[index])
        # These will be 128x128

        seis_img = (seis_img - seis_img.mean()) / (seis_img.std() + 1e-6)

        fault = torch.tensor(fault_img.copy(), dtype=torch.float32).unsqueeze(0)
        seis = torch.tensor(seis_img.copy(), dtype=torch.float32).unsqueeze(0)

        return seis, fault


class F3DFaciesSet(data.Dataset):
    def __init__(self, folder, shape=[128, 128], mode='train', random_resize=False):
        super().__init__()

        self.folder = folder
        if not os.path.exists(folder):
            raise FileNotFoundError(f"The folder {folder} does not exist.")
+
|
| 418 |
+
self.seises = np.load(join(folder, "{}/seismic.npy".format(mode)))
|
| 419 |
+
self.labels = np.load(join(folder, "{}/labels.npy".format(mode)))
|
| 420 |
+
self.image_shape = shape
|
| 421 |
+
|
| 422 |
+
if mode == 'train':
|
| 423 |
+
self.size_categories = [
|
| 424 |
+
(401, 701),
|
| 425 |
+
(701, 255),
|
| 426 |
+
(401, 255)
|
| 427 |
+
]
|
| 428 |
+
elif mode == 'val':
|
| 429 |
+
self.size_categories = [
|
| 430 |
+
(601, 200),
|
| 431 |
+
(200, 255),
|
| 432 |
+
(601, 255)
|
| 433 |
+
]
|
| 434 |
+
|
| 435 |
+
elif mode == 'test':
|
| 436 |
+
self.size_categories = [
|
| 437 |
+
(701, 255),
|
| 438 |
+
(200, 701),
|
| 439 |
+
(200, 255)
|
| 440 |
+
]
|
| 441 |
+
|
| 442 |
+
else:
|
| 443 |
+
raise ValueError("Mode must be 'train', 'val', or 'test'.")
|
| 444 |
+
def __len__(self):
|
| 445 |
+
# We will take cross sections along each dimension, so the length is the sum of all dimensions
|
| 446 |
+
|
| 447 |
+
return sum(self.seises.shape)
|
| 448 |
+
|
| 449 |
+
def random_transform(self, label, seis):
|
| 450 |
+
# Apply the same random transformations to the fault and seismic data
|
| 451 |
+
# Mirror the patch horizontally
|
| 452 |
+
if random.random() > 0.5:
|
| 453 |
+
label = np.fliplr(label)
|
| 454 |
+
seis = np.fliplr(seis)
|
| 455 |
+
|
| 456 |
+
# Mirror the patch vertically
|
| 457 |
+
if random.random() > 0.5:
|
| 458 |
+
label = np.flipud(label)
|
| 459 |
+
seis = np.flipud(seis)
|
| 460 |
+
|
| 461 |
+
return label, seis
|
| 462 |
+
|
| 463 |
+
def __getitem__(self, index):
|
| 464 |
+
|
| 465 |
+
m1, m2, m3 = self.seises.shape
|
| 466 |
+
|
| 467 |
+
if index < m1:
|
| 468 |
+
seis, label = self.seises[index, :, :], self.labels[index, :, :]
|
| 469 |
+
elif index < m1 + m2:
|
| 470 |
+
seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
|
| 471 |
+
elif index < m1 + m2 + m3:
|
| 472 |
+
seis, label = self.seises[:, :, index - m1 - m2], self.labels[:, :, index - m1 - m2]
|
| 473 |
+
else:
|
| 474 |
+
raise IndexError("Index out of bounds")
|
| 475 |
+
|
| 476 |
+
seis, label = self.random_transform(seis, label)
|
| 477 |
+
seis = (seis - seis.mean()) / (seis.std() + 1e-6)
|
| 478 |
+
|
| 479 |
+
seis, label = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0), torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)
|
| 480 |
+
|
| 481 |
+
# label is now shape [1, H, W]
|
| 482 |
+
# we want shape [6, H, W] with each slice being a binary mask depending on the int value of label
|
| 483 |
+
label = label.squeeze(0)
|
| 484 |
+
label = (label == torch.arange(6, device=label.device).view(6, 1, 1)).float()
|
| 485 |
+
|
| 486 |
+
return seis, label
|
| 487 |
+
|
| 488 |
+
class P3DFaciesSet(data.Dataset):
|
| 489 |
+
def __init__(self, folder, shape=[128, 128], mode='train', random_resize = False):
|
| 490 |
+
super().__init__()
|
| 491 |
+
|
| 492 |
+
self.folder = folder
|
| 493 |
+
if not os.path.exists(folder):
|
| 494 |
+
raise FileNotFoundError(f"The folder {folder} does not exist.")
|
| 495 |
+
|
| 496 |
+
self.random_resize = random_resize
|
| 497 |
+
|
| 498 |
+
# Validation set will be validation set from F3DSet
|
| 499 |
+
if mode == 'val': mode = 'train' # TEMPORARY SINCE P3D does not have labelled val set
|
| 500 |
+
|
| 501 |
+
self.mode = mode
|
| 502 |
+
self.image_shape = shape
|
| 503 |
+
|
| 504 |
+
self.s_path = join(folder, "{}/seismic.segy".format(mode))
|
| 505 |
+
self.l_path = join(folder, "{}/labels.segy".format(mode))
|
| 506 |
+
|
| 507 |
+
if mode != 'val':
|
| 508 |
+
with segyio.open(self.s_path, ignore_geometry=True) as seis_file:
|
| 509 |
+
self.seises = seis_file.trace.raw[:]
|
| 510 |
+
|
| 511 |
+
if self.mode in ['val', 'train']:
|
| 512 |
+
with segyio.open(self.l_path, ignore_geometry=True) as label_file:
|
| 513 |
+
self.labels = label_file.trace.raw[:]
|
| 514 |
+
else:
|
| 515 |
+
# Since the test files are unlabeled
|
| 516 |
+
self.labels = np.zeros_like(self.seises)
|
| 517 |
+
else:
|
| 518 |
+
f3d_file_path = "C:\\Users\\abhalekar\\Desktop\\DATASETS\\F3D_facies_DATASET"
|
| 519 |
+
self.seises = np.load(join(f3d_file_path, "val/seismic.npy"))
|
| 520 |
+
self.labels = np.load(join(f3d_file_path, "val/labels.npy"))
|
| 521 |
+
|
| 522 |
+
if mode == 'train':
|
| 523 |
+
m1, m2, m3 = 590, 782, 1006
|
| 524 |
+
elif mode == 'val':
|
| 525 |
+
m1, m2, m3 = 601, 200, 255
|
| 526 |
+
elif mode == 'test_1':
|
| 527 |
+
m1, m2, m3 = 841, 334, 1006
|
| 528 |
+
elif mode == 'test_2':
|
| 529 |
+
m1, m2, m3 = 251, 782, 1006
|
| 530 |
+
else:
|
| 531 |
+
raise ValueError("Mode must be 'train', 'test_2', 'val', or 'test_1'.")
|
| 532 |
+
|
| 533 |
+
self.size_categories = list(permutations([m1, m2, m3], 2))
|
| 534 |
+
|
| 535 |
+
self.seises = self.seises.reshape(m1, m2, m3)
|
| 536 |
+
self.labels = self.labels.reshape(m1, m2, m3)
|
| 537 |
+
|
| 538 |
+
def __len__(self):
|
| 539 |
+
# We will take cross sections along the first 2 dimensions ONLY
|
| 540 |
+
return self.seises.shape[0] + self.seises.shape[1]
|
| 541 |
+
|
| 542 |
+
def _random_transform(self, label, seis):
|
| 543 |
+
# Apply the same random transformations to the fault and seismic data
|
| 544 |
+
# Mirror the patch horizontally
|
| 545 |
+
if random.random() > 0.5:
|
| 546 |
+
label = np.fliplr(label)
|
| 547 |
+
seis = np.fliplr(seis)
|
| 548 |
+
|
| 549 |
+
# Mirror the patch vertically
|
| 550 |
+
if random.random() > 0.5:
|
| 551 |
+
label = np.flipud(label)
|
| 552 |
+
seis = np.flipud(seis)
|
| 553 |
+
|
| 554 |
+
# random rotation to 2D image label,seis
|
| 555 |
+
#r_int = random.randint(0, 3)
|
| 556 |
+
#label = np.rot90(label, r_int)
|
| 557 |
+
#seis = np.rot90(seis, r_int)
|
| 558 |
+
|
| 559 |
+
return label, seis
|
| 560 |
+
|
| 561 |
+
def _random_resize(self, label, seis, min_size = (256, 256)):
|
| 562 |
+
# Randomly resize the label and seismic data
|
| 563 |
+
r_height = random.randint(min_size[0], seis.shape[0])
|
| 564 |
+
r_width = random.randint(min_size[1], seis.shape[1])
|
| 565 |
+
|
| 566 |
+
r_pos_x = random.randint(0, seis.shape[0] - r_height)
|
| 567 |
+
r_pos_y = random.randint(0, seis.shape[1] - r_width)
|
| 568 |
+
|
| 569 |
+
label = label[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]
|
| 570 |
+
seis = seis[r_pos_x:r_pos_x + r_height, r_pos_y:r_pos_y + r_width]
|
| 571 |
+
|
| 572 |
+
return label, seis
|
| 573 |
+
|
| 574 |
+
def __getitem__(self, index):
|
| 575 |
+
|
| 576 |
+
m1, m2, m3 = self.seises.shape
|
| 577 |
+
|
| 578 |
+
if index < m1:
|
| 579 |
+
seis, label = self.seises[index, :, :], self.labels[index, :, :]
|
| 580 |
+
elif index < m1 + m2:
|
| 581 |
+
seis, label = self.seises[:, index - m1, :], self.labels[:, index - m1, :]
|
| 582 |
+
elif index < m1 + m2 + m3:
|
| 583 |
+
seis, label = self.seises[:, :, index - m1 - m2], self.labels[:, :, index - m1 - m2]
|
| 584 |
+
else:
|
| 585 |
+
raise IndexError("Index out of bounds")
|
| 586 |
+
|
| 587 |
+
seis, label = self._random_transform(seis, label)
|
| 588 |
+
if self.random_resize: seis, label = self._random_resize(seis, label)
|
| 589 |
+
|
| 590 |
+
seis = (seis - seis.mean()) / (seis.std() + 1e-6)
|
| 591 |
+
|
| 592 |
+
seis, label = torch.tensor(seis.copy(), dtype=torch.float32).unsqueeze(0), torch.tensor(label.copy(), dtype=torch.float32).unsqueeze(0)
|
| 593 |
+
|
| 594 |
+
# label is now shape [1, H, W]
|
| 595 |
+
# we want shape [6, H, W] with each slice being a binary mask depending on the int value of label
|
| 596 |
+
label = label.squeeze(0)
|
| 597 |
+
label = (label == torch.arange(1, 7, device=label.device).view(6, 1, 1)).float()
|
| 598 |
+
|
| 599 |
+
return seis, label
|
util/lars.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# LARS optimizer, implementation from MoCo v3:
|
| 8 |
+
# https://github.com/facebookresearch/moco-v3
|
| 9 |
+
# --------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LARS(torch.optim.Optimizer):
|
| 15 |
+
"""
|
| 16 |
+
LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
|
| 17 |
+
"""
|
| 18 |
+
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
|
| 19 |
+
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
|
| 20 |
+
super().__init__(params, defaults)
|
| 21 |
+
|
| 22 |
+
@torch.no_grad()
|
| 23 |
+
def step(self):
|
| 24 |
+
for g in self.param_groups:
|
| 25 |
+
for p in g['params']:
|
| 26 |
+
dp = p.grad
|
| 27 |
+
|
| 28 |
+
if dp is None:
|
| 29 |
+
continue
|
| 30 |
+
|
| 31 |
+
if p.ndim > 1: # if not normalization gamma/beta or bias
|
| 32 |
+
dp = dp.add(p, alpha=g['weight_decay'])
|
| 33 |
+
param_norm = torch.norm(p)
|
| 34 |
+
update_norm = torch.norm(dp)
|
| 35 |
+
one = torch.ones_like(param_norm)
|
| 36 |
+
q = torch.where(param_norm > 0.,
|
| 37 |
+
torch.where(update_norm > 0,
|
| 38 |
+
(g['trust_coefficient'] * param_norm / update_norm), one),
|
| 39 |
+
one)
|
| 40 |
+
dp = dp.mul(q)
|
| 41 |
+
|
| 42 |
+
param_state = self.state[p]
|
| 43 |
+
if 'mu' not in param_state:
|
| 44 |
+
param_state['mu'] = torch.zeros_like(p)
|
| 45 |
+
mu = param_state['mu']
|
| 46 |
+
mu.mul_(g['momentum']).add_(dp)
|
| 47 |
+
p.add_(mu, alpha=-g['lr'])
|
util/lr_decay.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# ELECTRA https://github.com/google-research/electra
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
|
| 16 |
+
"""
|
| 17 |
+
Parameter groups for layer-wise lr decay
|
| 18 |
+
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
|
| 19 |
+
"""
|
| 20 |
+
param_group_names = {}
|
| 21 |
+
param_groups = {}
|
| 22 |
+
|
| 23 |
+
num_layers = len(model.blocks) + 1
|
| 24 |
+
|
| 25 |
+
layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
|
| 26 |
+
|
| 27 |
+
for n, p in model.named_parameters():
|
| 28 |
+
if not p.requires_grad:
|
| 29 |
+
continue
|
| 30 |
+
|
| 31 |
+
# no decay: all 1D parameters and model specific ones
|
| 32 |
+
if p.ndim == 1 or n in no_weight_decay_list:
|
| 33 |
+
g_decay = "no_decay"
|
| 34 |
+
this_decay = 0.
|
| 35 |
+
else:
|
| 36 |
+
g_decay = "decay"
|
| 37 |
+
this_decay = weight_decay
|
| 38 |
+
|
| 39 |
+
layer_id = get_layer_id_for_vit(n, num_layers)
|
| 40 |
+
group_name = "layer_%d_%s" % (layer_id, g_decay)
|
| 41 |
+
|
| 42 |
+
if group_name not in param_group_names:
|
| 43 |
+
this_scale = layer_scales[layer_id]
|
| 44 |
+
|
| 45 |
+
param_group_names[group_name] = {
|
| 46 |
+
"lr_scale": this_scale,
|
| 47 |
+
"weight_decay": this_decay,
|
| 48 |
+
"params": [],
|
| 49 |
+
}
|
| 50 |
+
param_groups[group_name] = {
|
| 51 |
+
"lr_scale": this_scale,
|
| 52 |
+
"weight_decay": this_decay,
|
| 53 |
+
"params": [],
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
param_group_names[group_name]["params"].append(n)
|
| 57 |
+
param_groups[group_name]["params"].append(p)
|
| 58 |
+
|
| 59 |
+
# print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
|
| 60 |
+
|
| 61 |
+
return list(param_groups.values())
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_layer_id_for_vit(name, num_layers):
|
| 65 |
+
"""
|
| 66 |
+
Assign a parameter with its layer id
|
| 67 |
+
Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
|
| 68 |
+
"""
|
| 69 |
+
if name in ['cls_token', 'pos_embed']:
|
| 70 |
+
return 0
|
| 71 |
+
elif name.startswith('patch_embed'):
|
| 72 |
+
return 0
|
| 73 |
+
elif name.startswith('blocks'):
|
| 74 |
+
return int(name.split('.')[1]) + 1
|
| 75 |
+
else:
|
| 76 |
+
return num_layers
|
util/lr_sched.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
| 10 |
+
"""Decay the learning rate with half-cycle cosine after warmup"""
|
| 11 |
+
if epoch < args.warmup_epochs:
|
| 12 |
+
lr = args.lr * epoch / args.warmup_epochs
|
| 13 |
+
else:
|
| 14 |
+
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
|
| 15 |
+
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
|
| 16 |
+
for param_group in optimizer.param_groups:
|
| 17 |
+
if "lr_scale" in param_group:
|
| 18 |
+
param_group["lr"] = lr * param_group["lr_scale"]
|
| 19 |
+
else:
|
| 20 |
+
param_group["lr"] = lr
|
| 21 |
+
return lr
|
util/metrics.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""
|
| 3 |
+
refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
|
| 4 |
+
"""
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
__all__ = ['SegmentationMetric']
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
confusionMetric # 注意:此处横着代表预测值,竖着代表真实值,与之前介绍的相反
|
| 11 |
+
P\L P N
|
| 12 |
+
P TP FP
|
| 13 |
+
N FN TN
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SegmentationMetric(object):
|
| 18 |
+
def __init__(self, numClass):
|
| 19 |
+
self.numClass = numClass
|
| 20 |
+
self.confusionMatrix = np.zeros((self.numClass,) * 2) # 混淆矩阵(空)
|
| 21 |
+
|
| 22 |
+
def pixelAccuracy(self):
|
| 23 |
+
# return all class overall pixel accuracy 正确的像素占总像素的比例
|
| 24 |
+
# PA = acc = (TP + TN) / (TP + TN + FP + TN)
|
| 25 |
+
acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
|
| 26 |
+
return acc
|
| 27 |
+
|
| 28 |
+
def classPixelAccuracy(self):
|
| 29 |
+
# return each category pixel accuracy(A more accurate way to call it precision)
|
| 30 |
+
# acc = (TP) / TP + FP
|
| 31 |
+
classAcc = np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=1)
|
| 32 |
+
return classAcc # 返回的是一个列表值,如:[0.90, 0.80, 0.96],表示类别1 2 3各类别的预测准确率
|
| 33 |
+
|
| 34 |
+
def meanPixelAccuracy(self):
|
| 35 |
+
"""
|
| 36 |
+
Mean Pixel Accuracy(MPA,均像素精度):是PA的一种简单提升,计算每个类内被正确分类像素数的比例,之后求所有类的平均。
|
| 37 |
+
:return:
|
| 38 |
+
"""
|
| 39 |
+
classAcc = self.classPixelAccuracy()
|
| 40 |
+
meanAcc = np.nanmean(classAcc) # np.nanmean 求平均值,nan表示遇到Nan类型,其值取为0
|
| 41 |
+
return meanAcc # 返回单个值,如:np.nanmean([0.90, 0.80, 0.96, nan, nan]) = (0.90 + 0.80 + 0.96) / 3 = 0.89
|
| 42 |
+
|
| 43 |
+
def IntersectionOverUnion(self):
|
| 44 |
+
# Intersection = TP Union = TP + FP + FN
|
| 45 |
+
# IoU = TP / (TP + FP + FN)
|
| 46 |
+
intersection = np.diag(self.confusionMatrix) # 取对角元素的值,返回列表
|
| 47 |
+
union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
|
| 48 |
+
self.confusionMatrix) # axis = 1表示混淆矩阵行的值,返回列表; axis = 0表示取混淆矩阵列的值,返回列表
|
| 49 |
+
IoU = intersection / union # 返回列表,其值为各个类别的IoU
|
| 50 |
+
return IoU
|
| 51 |
+
|
| 52 |
+
def meanIntersectionOverUnion(self):
|
| 53 |
+
mIoU = np.nanmean(self.IntersectionOverUnion()) # 求各类别IoU的平均
|
| 54 |
+
return mIoU
|
| 55 |
+
|
| 56 |
+
def genConfusionMatrix(self, imgPredict, imgLabel): #
|
| 57 |
+
"""
|
| 58 |
+
同FCN中score.py的fast_hist()函数,计算混淆矩阵
|
| 59 |
+
:param imgPredict:
|
| 60 |
+
:param imgLabel:
|
| 61 |
+
:return: 混淆矩阵
|
| 62 |
+
"""
|
| 63 |
+
# remove classes from unlabeled pixels in gt image and predict
|
| 64 |
+
mask = (imgLabel >= 0) & (imgLabel < self.numClass)
|
| 65 |
+
label = self.numClass * imgLabel[mask] + imgPredict[mask]
|
| 66 |
+
count = np.bincount(label, minlength=self.numClass ** 2)
|
| 67 |
+
confusionMatrix = count.reshape(self.numClass, self.numClass)
|
| 68 |
+
# print(confusionMatrix)
|
| 69 |
+
return confusionMatrix
|
| 70 |
+
|
| 71 |
+
def Frequency_Weighted_Intersection_over_Union(self):
|
| 72 |
+
"""
|
| 73 |
+
FWIoU,频权交并比:为MIoU的一种提升,这种方法根据每个类出现的频率为其设置权重。
|
| 74 |
+
FWIOU = [(TP+FN)/(TP+FP+TN+FN)] *[TP / (TP + FP + FN)]
|
| 75 |
+
"""
|
| 76 |
+
freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
|
| 77 |
+
iu = np.diag(self.confusion_matrix) / (
|
| 78 |
+
np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
|
| 79 |
+
np.diag(self.confusion_matrix))
|
| 80 |
+
FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
|
| 81 |
+
return FWIoU
|
| 82 |
+
|
| 83 |
+
def addBatch(self, imgPredict, imgLabel):
|
| 84 |
+
assert imgPredict.shape == imgLabel.shape
|
| 85 |
+
self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel) # 得到混淆矩阵
|
| 86 |
+
return self.confusionMatrix
|
| 87 |
+
|
| 88 |
+
def reset(self):
|
| 89 |
+
self.confusionMatrix = np.zeros((self.numClass, self.numClass))
|
| 90 |
+
|
util/misc.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# References:
|
| 8 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 9 |
+
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
|
| 10 |
+
# --------------------------------------------------------
|
| 11 |
+
|
| 12 |
+
import builtins
|
| 13 |
+
import datetime
|
| 14 |
+
import os
|
| 15 |
+
import time
|
| 16 |
+
from collections import defaultdict, deque
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
from torch import inf
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SmoothedValue(object):
|
| 25 |
+
"""Track a series of values and provide access to smoothed values over a
|
| 26 |
+
window or the global series average.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, window_size=20, fmt=None):
|
| 30 |
+
if fmt is None:
|
| 31 |
+
fmt = "{median:.4f} ({global_avg:.4f})"
|
| 32 |
+
self.deque = deque(maxlen=window_size)
|
| 33 |
+
self.total = 0.0
|
| 34 |
+
self.count = 0
|
| 35 |
+
self.fmt = fmt
|
| 36 |
+
|
| 37 |
+
def update(self, value, n=1):
|
| 38 |
+
self.deque.append(value)
|
| 39 |
+
self.count += n
|
| 40 |
+
self.total += value * n
|
| 41 |
+
|
| 42 |
+
def synchronize_between_processes(self):
|
| 43 |
+
"""
|
| 44 |
+
Warning: does not synchronize the deque!
|
| 45 |
+
"""
|
| 46 |
+
if not is_dist_avail_and_initialized():
|
| 47 |
+
return
|
| 48 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
| 49 |
+
dist.barrier()
|
| 50 |
+
dist.all_reduce(t)
|
| 51 |
+
t = t.tolist()
|
| 52 |
+
self.count = int(t[0])
|
| 53 |
+
self.total = t[1]
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def median(self):
|
| 57 |
+
d = torch.tensor(list(self.deque))
|
| 58 |
+
return d.median().item()
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def avg(self):
|
| 62 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
| 63 |
+
return d.mean().item()
|
| 64 |
+
|
| 65 |
+
@property
|
| 66 |
+
def global_avg(self):
|
| 67 |
+
return self.total / self.count
|
| 68 |
+
|
| 69 |
+
@property
|
| 70 |
+
def max(self):
|
| 71 |
+
return max(self.deque)
|
| 72 |
+
|
| 73 |
+
@property
|
| 74 |
+
def value(self):
|
| 75 |
+
return self.deque[-1]
|
| 76 |
+
|
| 77 |
+
def __str__(self):
|
| 78 |
+
return self.fmt.format(
|
| 79 |
+
median=self.median,
|
| 80 |
+
avg=self.avg,
|
| 81 |
+
global_avg=self.global_avg,
|
| 82 |
+
max=self.max,
|
| 83 |
+
value=self.value)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MetricLogger(object):
|
| 87 |
+
def __init__(self, delimiter="\t"):
|
| 88 |
+
self.meters = defaultdict(SmoothedValue)
|
| 89 |
+
self.delimiter = delimiter
|
| 90 |
+
|
| 91 |
+
def update(self, **kwargs):
|
| 92 |
+
for k, v in kwargs.items():
|
| 93 |
+
if v is None:
|
| 94 |
+
continue
|
| 95 |
+
if isinstance(v, torch.Tensor):
|
| 96 |
+
v = v.item()
|
| 97 |
+
assert isinstance(v, (float, int))
|
| 98 |
+
self.meters[k].update(v)
|
| 99 |
+
|
| 100 |
+
def __getattr__(self, attr):
|
| 101 |
+
if attr in self.meters:
|
| 102 |
+
return self.meters[attr]
|
| 103 |
+
if attr in self.__dict__:
|
| 104 |
+
return self.__dict__[attr]
|
| 105 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
| 106 |
+
type(self).__name__, attr))
|
| 107 |
+
|
| 108 |
+
def __str__(self):
|
| 109 |
+
loss_str = []
|
| 110 |
+
for name, meter in self.meters.items():
|
| 111 |
+
loss_str.append(
|
| 112 |
+
"{}: {}".format(name, str(meter))
|
| 113 |
+
)
|
| 114 |
+
return self.delimiter.join(loss_str)
|
| 115 |
+
|
| 116 |
+
def synchronize_between_processes(self):
|
| 117 |
+
for meter in self.meters.values():
|
| 118 |
+
meter.synchronize_between_processes()
|
| 119 |
+
|
| 120 |
+
def add_meter(self, name, meter):
|
| 121 |
+
self.meters[name] = meter
|
| 122 |
+
|
| 123 |
+
def log_every(self, iterable, print_freq, header=None):
|
| 124 |
+
i = 0
|
| 125 |
+
if not header:
|
| 126 |
+
header = ''
|
| 127 |
+
start_time = time.time()
|
| 128 |
+
end = time.time()
|
| 129 |
+
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
| 130 |
+
data_time = SmoothedValue(fmt='{avg:.4f}')
|
| 131 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
| 132 |
+
log_msg = [
|
| 133 |
+
header,
|
| 134 |
+
'[{0' + space_fmt + '}/{1}]',
|
| 135 |
+
'eta: {eta}',
|
| 136 |
+
'{meters}',
|
| 137 |
+
'time: {time}',
|
| 138 |
+
'data: {data}'
|
| 139 |
+
]
|
| 140 |
+
if torch.cuda.is_available():
|
| 141 |
+
log_msg.append('max mem: {memory:.0f}')
|
| 142 |
+
log_msg = self.delimiter.join(log_msg)
|
| 143 |
+
MB = 1024.0 * 1024.0
|
| 144 |
+
for obj in iterable:
|
| 145 |
+
data_time.update(time.time() - end)
|
| 146 |
+
yield obj
|
| 147 |
+
iter_time.update(time.time() - end)
|
| 148 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
| 149 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
| 150 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
| 151 |
+
if torch.cuda.is_available():
|
| 152 |
+
print(log_msg.format(
|
| 153 |
+
i, len(iterable), eta=eta_string,
|
| 154 |
+
meters=str(self),
|
| 155 |
+
time=str(iter_time), data=str(data_time),
|
| 156 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
| 157 |
+
else:
|
| 158 |
+
print(log_msg.format(
|
| 159 |
+
i, len(iterable), eta=eta_string,
|
| 160 |
+
meters=str(self),
|
| 161 |
+
time=str(iter_time), data=str(data_time)))
|
| 162 |
+
i += 1
|
| 163 |
+
end = time.time()
|
| 164 |
+
total_time = time.time() - start_time
|
| 165 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
| 166 |
+
print('{} Total time: {} ({:.4f} s / it)'.format(
|
| 167 |
+
header, total_time_str, total_time / len(iterable)))
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def setup_for_distributed(is_master):
|
| 171 |
+
"""
|
| 172 |
+
This function disables printing when not in master process
|
| 173 |
+
"""
|
| 174 |
+
builtin_print = builtins.print
|
| 175 |
+
|
| 176 |
+
def print(*args, **kwargs):
|
| 177 |
+
force = kwargs.pop('force', False)
|
| 178 |
+
force = force or (get_world_size() > 8)
|
| 179 |
+
if is_master or force:
|
| 180 |
+
now = datetime.datetime.now().time()
|
| 181 |
+
builtin_print('[{}] '.format(now), end='') # print with time stamp
|
| 182 |
+
builtin_print(*args, **kwargs)
|
| 183 |
+
|
| 184 |
+
builtins.print = print
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def is_dist_avail_and_initialized():
|
| 188 |
+
if not dist.is_available():
|
| 189 |
+
return False
|
| 190 |
+
if not dist.is_initialized():
|
| 191 |
+
return False
|
| 192 |
+
return True
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_world_size():
|
| 196 |
+
if not is_dist_avail_and_initialized():
|
| 197 |
+
return 1
|
| 198 |
+
return dist.get_world_size()
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_rank():
|
| 202 |
+
if not is_dist_avail_and_initialized():
|
| 203 |
+
return 0
|
| 204 |
+
return dist.get_rank()
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def is_main_process():
|
| 208 |
+
return get_rank() == 0
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def save_on_master(*args, **kwargs):
|
| 212 |
+
if is_main_process():
|
| 213 |
+
torch.save(*args, **kwargs)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def init_distributed_mode(args):
|
| 217 |
+
if args.dist_on_itp:
|
| 218 |
+
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
| 219 |
+
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
| 220 |
+
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
| 221 |
+
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
|
| 222 |
+
os.environ['LOCAL_RANK'] = str(args.gpu)
|
| 223 |
+
os.environ['RANK'] = str(args.rank)
|
| 224 |
+
os.environ['WORLD_SIZE'] = str(args.world_size)
|
| 225 |
+
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
|
| 226 |
+
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
| 227 |
+
args.rank = int(os.environ["RANK"])
|
| 228 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
| 229 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
| 230 |
+
elif 'SLURM_PROCID' in os.environ:
|
| 231 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
| 232 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
| 233 |
+
else:
|
| 234 |
+
print('Not using distributed mode')
|
| 235 |
+
setup_for_distributed(is_master=True) # hack
|
| 236 |
+
args.distributed = False
|
| 237 |
+
return
|
| 238 |
+
|
| 239 |
+
args.distributed = True
|
| 240 |
+
|
| 241 |
+
torch.cuda.set_device(args.gpu)
|
| 242 |
+
args.dist_backend = 'nccl'
|
| 243 |
+
print('| distributed init (rank {}): {}, gpu {}'.format(
|
| 244 |
+
args.rank, args.dist_url, args.gpu), flush=True)
|
| 245 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
| 246 |
+
world_size=args.world_size, rank=args.rank)
|
| 247 |
+
torch.distributed.barrier()
|
| 248 |
+
setup_for_distributed(args.rank == 0)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class NativeScalerWithGradNormCount:
|
| 252 |
+
state_dict_key = "amp_scaler"
|
| 253 |
+
|
| 254 |
+
def __init__(self):
|
| 255 |
+
self._scaler = torch.cuda.amp.GradScaler()
|
| 256 |
+
|
| 257 |
+
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
|
| 258 |
+
self._scaler.scale(loss).backward(create_graph=create_graph)
|
| 259 |
+
if update_grad:
|
| 260 |
+
if clip_grad is not None:
|
| 261 |
+
assert parameters is not None
|
| 262 |
+
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
|
| 263 |
+
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
|
| 264 |
+
else:
|
| 265 |
+
self._scaler.unscale_(optimizer)
|
| 266 |
+
norm = get_grad_norm_(parameters)
|
| 267 |
+
self._scaler.step(optimizer)
|
| 268 |
+
self._scaler.update()
|
| 269 |
+
else:
|
| 270 |
+
norm = None
|
| 271 |
+
return norm
|
| 272 |
+
|
| 273 |
+
def state_dict(self):
|
| 274 |
+
return self._scaler.state_dict()
|
| 275 |
+
|
| 276 |
+
def load_state_dict(self, state_dict):
|
| 277 |
+
self._scaler.load_state_dict(state_dict)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
|
| 281 |
+
if isinstance(parameters, torch.Tensor):
|
| 282 |
+
parameters = [parameters]
|
| 283 |
+
parameters = [p for p in parameters if p.grad is not None]
|
| 284 |
+
norm_type = float(norm_type)
|
| 285 |
+
if len(parameters) == 0:
|
| 286 |
+
return torch.tensor(0.)
|
| 287 |
+
device = parameters[0].grad.device
|
| 288 |
+
if norm_type == inf:
|
| 289 |
+
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
|
| 290 |
+
else:
|
| 291 |
+
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
| 292 |
+
return total_norm
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
|
| 296 |
+
output_dir = Path(args.output_dir)
|
| 297 |
+
epoch_name = str(epoch)
|
| 298 |
+
if loss_scaler is not None:
|
| 299 |
+
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
|
| 300 |
+
for checkpoint_path in checkpoint_paths:
|
| 301 |
+
to_save = {
|
| 302 |
+
'model': model_without_ddp.state_dict(),
|
| 303 |
+
'optimizer': optimizer.state_dict(),
|
| 304 |
+
'epoch': epoch,
|
| 305 |
+
'scaler': loss_scaler.state_dict(),
|
| 306 |
+
'args': args,
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
save_on_master(to_save, checkpoint_path)
|
| 310 |
+
else:
|
| 311 |
+
client_state = {'epoch': epoch}
|
| 312 |
+
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
def load_model(args, model_without_ddp, optimizer, loss_scaler):
|
| 316 |
+
if args.resume:
|
| 317 |
+
if args.resume.startswith('https'):
|
| 318 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 319 |
+
args.resume, map_location='cpu', check_hash=True)
|
| 320 |
+
else:
|
| 321 |
+
checkpoint = torch.load(args.resume, map_location='cpu')
|
| 322 |
+
model_without_ddp.load_state_dict(checkpoint['model'])
|
| 323 |
+
print("Resume checkpoint %s" % args.resume)
|
| 324 |
+
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
|
| 325 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 326 |
+
args.start_epoch = checkpoint['epoch'] + 1
|
| 327 |
+
if 'scaler' in checkpoint:
|
| 328 |
+
loss_scaler.load_state_dict(checkpoint['scaler'])
|
| 329 |
+
print("With optim & sched!")
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def all_reduce_mean(x):
|
| 333 |
+
world_size = get_world_size()
|
| 334 |
+
if world_size > 1:
|
| 335 |
+
x_reduce = torch.tensor(x).cuda()
|
| 336 |
+
dist.all_reduce(x_reduce)
|
| 337 |
+
x_reduce /= world_size
|
| 338 |
+
return x_reduce.item()
|
| 339 |
+
else:
|
| 340 |
+
return x
|
util/msssim.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from math import exp
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def gaussian(window_size, sigma):
|
| 7 |
+
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
| 8 |
+
return gauss/gauss.sum()
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_window(window_size, channel=1):
|
| 12 |
+
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
| 13 |
+
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
| 14 |
+
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
| 15 |
+
return window
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
|
| 19 |
+
# Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
|
| 20 |
+
if val_range is None:
|
| 21 |
+
if torch.max(img1) > 128:
|
| 22 |
+
max_val = 255
|
| 23 |
+
else:
|
| 24 |
+
max_val = 1
|
| 25 |
+
|
| 26 |
+
if torch.min(img1) < -0.5:
|
| 27 |
+
min_val = -1
|
| 28 |
+
else:
|
| 29 |
+
min_val = 0
|
| 30 |
+
L = max_val - min_val
|
| 31 |
+
else:
|
| 32 |
+
L = val_range
|
| 33 |
+
|
| 34 |
+
padd = 0
|
| 35 |
+
(_, channel, height, width) = img1.size()
|
| 36 |
+
if window is None:
|
| 37 |
+
real_size = min(window_size, height, width)
|
| 38 |
+
window = create_window(real_size, channel=channel).to(img1.device)
|
| 39 |
+
|
| 40 |
+
mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
|
| 41 |
+
mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
|
| 42 |
+
|
| 43 |
+
mu1_sq = mu1.pow(2)
|
| 44 |
+
mu2_sq = mu2.pow(2)
|
| 45 |
+
mu1_mu2 = mu1 * mu2
|
| 46 |
+
|
| 47 |
+
sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
|
| 48 |
+
sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
|
| 49 |
+
sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
|
| 50 |
+
|
| 51 |
+
C1 = (0.01 * L) ** 2
|
| 52 |
+
C2 = (0.03 * L) ** 2
|
| 53 |
+
|
| 54 |
+
v1 = 2.0 * sigma12 + C2
|
| 55 |
+
v2 = sigma1_sq + sigma2_sq + C2
|
| 56 |
+
cs = torch.mean(v1 / v2) # contrast sensitivity
|
| 57 |
+
|
| 58 |
+
ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
|
| 59 |
+
|
| 60 |
+
if size_average:
|
| 61 |
+
ret = ssim_map.mean()
|
| 62 |
+
else:
|
| 63 |
+
ret = ssim_map.mean(1).mean(1).mean(1)
|
| 64 |
+
|
| 65 |
+
if full:
|
| 66 |
+
return ret, cs
|
| 67 |
+
return ret
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=True):
|
| 71 |
+
device = img1.device
|
| 72 |
+
weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
|
| 73 |
+
levels = weights.size()[0]
|
| 74 |
+
mssim = []
|
| 75 |
+
mcs = []
|
| 76 |
+
for _ in range(levels):
|
| 77 |
+
sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
|
| 78 |
+
mssim.append(sim)
|
| 79 |
+
mcs.append(cs)
|
| 80 |
+
|
| 81 |
+
img1 = F.avg_pool2d(img1, (2, 2))
|
| 82 |
+
img2 = F.avg_pool2d(img2, (2, 2))
|
| 83 |
+
|
| 84 |
+
mssim = torch.stack(mssim)
|
| 85 |
+
mcs = torch.stack(mcs)
|
| 86 |
+
|
| 87 |
+
# Normalize (to avoid NaNs during training unstable models, not compliant with original definition)
|
| 88 |
+
if normalize:
|
| 89 |
+
mssim = (mssim + 1) / 2
|
| 90 |
+
mcs = (mcs + 1) / 2
|
| 91 |
+
|
| 92 |
+
pow1 = mcs ** weights
|
| 93 |
+
pow2 = mssim ** weights
|
| 94 |
+
# From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
|
| 95 |
+
output = torch.prod(pow1[:-1] * pow2[-1])
|
| 96 |
+
return output
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Classes to re-use window
|
| 100 |
+
class SSIM(torch.nn.Module):
|
| 101 |
+
def __init__(self, window_size=11, size_average=True, val_range=None):
|
| 102 |
+
super(SSIM, self).__init__()
|
| 103 |
+
self.window_size = window_size
|
| 104 |
+
self.size_average = size_average
|
| 105 |
+
self.val_range = val_range
|
| 106 |
+
|
| 107 |
+
# Assume 1 channel for SSIM
|
| 108 |
+
self.channel = 1
|
| 109 |
+
self.window = create_window(window_size)
|
| 110 |
+
|
| 111 |
+
def forward(self, img1, img2):
|
| 112 |
+
(_, channel, _, _) = img1.size()
|
| 113 |
+
|
| 114 |
+
if channel == self.channel and self.window.dtype == img1.dtype:
|
| 115 |
+
window = self.window
|
| 116 |
+
else:
|
| 117 |
+
window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
|
| 118 |
+
self.window = window
|
| 119 |
+
self.channel = channel
|
| 120 |
+
|
| 121 |
+
return 1 - ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
|
| 122 |
+
|
| 123 |
+
class MSSSIM(torch.nn.Module):
|
| 124 |
+
def __init__(self, window_size=11, size_average=True, channel=1):
|
| 125 |
+
super(MSSSIM, self).__init__()
|
| 126 |
+
self.window_size = window_size
|
| 127 |
+
self.size_average = size_average
|
| 128 |
+
self.channel = channel
|
| 129 |
+
|
| 130 |
+
def forward(self, img1, img2):
|
| 131 |
+
# TODO: store window between calls if possible
|
| 132 |
+
return 1 - msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
|
| 133 |
+
|
| 134 |
+
class PSNR(torch.nn.Module):
|
| 135 |
+
def __init__(self):
|
| 136 |
+
super(PSNR, self).__init__()
|
| 137 |
+
|
| 138 |
+
def torchPSNR(self,tar_img, prd_img):
|
| 139 |
+
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
|
| 140 |
+
rmse = (imdff**2).mean().sqrt()
|
| 141 |
+
ps = 20*torch.log10(1/rmse)
|
| 142 |
+
return ps
|
| 143 |
+
|
| 144 |
+
def forward(self, img1, img2):
|
| 145 |
+
# TODO: store window between calls if possible
|
| 146 |
+
return self.torchPSNR(img1, img2)
|
util/pos_embed.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Position embedding utils
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# 2D sine-cosine position embedding
|
| 16 |
+
# References:
|
| 17 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 18 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 21 |
+
"""
|
| 22 |
+
grid_size: int of the grid height and width
|
| 23 |
+
return:
|
| 24 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 25 |
+
"""
|
| 26 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 27 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 28 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 29 |
+
grid = np.stack(grid, axis=0)
|
| 30 |
+
|
| 31 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 32 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 33 |
+
if cls_token:
|
| 34 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 35 |
+
return pos_embed
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 39 |
+
assert embed_dim % 2 == 0
|
| 40 |
+
|
| 41 |
+
# use half of dimensions to encode grid_h
|
| 42 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 43 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 44 |
+
|
| 45 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 46 |
+
return emb
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 50 |
+
"""
|
| 51 |
+
embed_dim: output dimension for each position
|
| 52 |
+
pos: a list of positions to be encoded: size (M,)
|
| 53 |
+
out: (M, D)
|
| 54 |
+
"""
|
| 55 |
+
assert embed_dim % 2 == 0
|
| 56 |
+
omega = np.arange(embed_dim // 2, dtype=np.float)
|
| 57 |
+
omega /= embed_dim / 2.
|
| 58 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 59 |
+
|
| 60 |
+
pos = pos.reshape(-1) # (M,)
|
| 61 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 62 |
+
|
| 63 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 64 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 65 |
+
|
| 66 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 67 |
+
return emb
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# --------------------------------------------------------
|
| 71 |
+
# Interpolate position embeddings for high-resolution
|
| 72 |
+
# References:
|
| 73 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 74 |
+
# --------------------------------------------------------
|
| 75 |
+
def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None):
|
| 76 |
+
if 'pos_embed' in checkpoint_model:
|
| 77 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
| 78 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 79 |
+
num_patches = model.patch_embed.num_patches
|
| 80 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
| 81 |
+
# height (== width) for the checkpoint position embedding
|
| 82 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 83 |
+
# height (== width) for the new position embedding
|
| 84 |
+
new_size = int(num_patches ** 0.5)
|
| 85 |
+
# class_token and dist_token are kept unchanged
|
| 86 |
+
if orig_size != new_size:
|
| 87 |
+
if newsize1 == None:
|
| 88 |
+
newsize1,newsize2 = new_size,new_size
|
| 89 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2))
|
| 90 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 91 |
+
# only the position tokens are interpolated
|
| 92 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 93 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 94 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 95 |
+
pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False)
|
| 96 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 97 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 98 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
| 99 |
+
# elif orig_size > new_size:
|
| 100 |
+
# print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 101 |
+
# pos_tokens = get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True)
|
| 102 |
+
# pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0)
|
| 103 |
+
# checkpoint_model['pos_embed'] = pos_tokens
|
| 104 |
+
|
util/pos_embedtest.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
# --------------------------------------------------------
|
| 7 |
+
# Position embedding utils
|
| 8 |
+
# --------------------------------------------------------
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# 2D sine-cosine position embedding
|
| 16 |
+
# References:
|
| 17 |
+
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
|
| 18 |
+
# MoCo v3: https://github.com/facebookresearch/moco-v3
|
| 19 |
+
# --------------------------------------------------------
|
| 20 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
|
| 21 |
+
"""
|
| 22 |
+
grid_size: int of the grid height and width
|
| 23 |
+
return:
|
| 24 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
| 25 |
+
"""
|
| 26 |
+
grid_h = np.arange(grid_size, dtype=np.float32)
|
| 27 |
+
grid_w = np.arange(grid_size, dtype=np.float32)
|
| 28 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
| 29 |
+
grid = np.stack(grid, axis=0)
|
| 30 |
+
|
| 31 |
+
grid = grid.reshape([2, 1, grid_size, grid_size])
|
| 32 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 33 |
+
if cls_token:
|
| 34 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 35 |
+
return pos_embed
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 39 |
+
assert embed_dim % 2 == 0
|
| 40 |
+
|
| 41 |
+
# use half of dimensions to encode grid_h
|
| 42 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
| 43 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
| 44 |
+
|
| 45 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
| 46 |
+
return emb
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 50 |
+
"""
|
| 51 |
+
embed_dim: output dimension for each position
|
| 52 |
+
pos: a list of positions to be encoded: size (M,)
|
| 53 |
+
out: (M, D)
|
| 54 |
+
"""
|
| 55 |
+
assert embed_dim % 2 == 0
|
| 56 |
+
omega = np.arange(embed_dim // 2, dtype=np.float)
|
| 57 |
+
omega /= embed_dim / 2.
|
| 58 |
+
omega = 1. / 10000**omega # (D/2,)
|
| 59 |
+
|
| 60 |
+
pos = pos.reshape(-1) # (M,)
|
| 61 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
| 62 |
+
|
| 63 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 64 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 65 |
+
|
| 66 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 67 |
+
return emb
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# --------------------------------------------------------
|
| 71 |
+
# Interpolate position embeddings for high-resolution
|
| 72 |
+
# References:
|
| 73 |
+
# DeiT: https://github.com/facebookresearch/deit
|
| 74 |
+
# --------------------------------------------------------
|
| 75 |
+
def interpolate_pos_embed(model, checkpoint_model,newsize1=None,newsize2=None):
|
| 76 |
+
if 'pos_embed' in checkpoint_model:
|
| 77 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
| 78 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 79 |
+
num_patches = model.patch_embed.num_patches
|
| 80 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
| 81 |
+
# height (== width) for the checkpoint position embedding
|
| 82 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 83 |
+
# height (== width) for the new position embedding
|
| 84 |
+
new_size = int(num_patches ** 0.5)
|
| 85 |
+
# class_token and dist_token are kept unchanged
|
| 86 |
+
if orig_size != new_size:
|
| 87 |
+
if newsize1 == None:
|
| 88 |
+
newsize1,newsize2 = new_size,new_size
|
| 89 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, newsize1, newsize2))
|
| 90 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 91 |
+
# only the position tokens are interpolated
|
| 92 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 93 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 94 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 95 |
+
pos_tokens, size=(newsize1, newsize2), mode='bicubic', align_corners=False)
|
| 96 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 97 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 98 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
| 99 |
+
|
| 100 |
+
def interpolate_dec_embed(model, checkpoint_model):
|
| 101 |
+
if 'decoder_pos_embed' in checkpoint_model:
|
| 102 |
+
pos_embed_checkpoint = checkpoint_model['decoder_pos_embed']
|
| 103 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
| 104 |
+
num_patches = model.decoder_pos_embed.num_patches
|
| 105 |
+
num_extra_tokens = model.decoder_pos_embed.shape[-2] - num_patches
|
| 106 |
+
# height (== width) for the checkpoint position embedding
|
| 107 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
| 108 |
+
# height (== width) for the new position embedding
|
| 109 |
+
new_size = int(num_patches ** 0.5)
|
| 110 |
+
# class_token and dist_token are kept unchanged
|
| 111 |
+
if orig_size != new_size:
|
| 112 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 113 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
| 114 |
+
# only the position tokens are interpolated
|
| 115 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
| 116 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
| 117 |
+
pos_tokens = torch.nn.functional.interpolate(
|
| 118 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
| 119 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
| 120 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
| 121 |
+
checkpoint_model['decoder_pos_embed'] = new_pos_embed
|
| 122 |
+
# elif orig_size > new_size:
|
| 123 |
+
# print("Position generate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
| 124 |
+
# pos_tokens = get_2d_sincos_pos_embed(embedding_size, new_size, cls_token=True)
|
| 125 |
+
# pos_tokens = torch.from_numpy(pos_tokens).float().unsqueeze(0)
|
| 126 |
+
# checkpoint_model['pos_embed'] = pos_tokens
|
| 127 |
+
|
util/post_processing.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import PIL.Image as Image
|
| 4 |
+
import torchvision.transforms as transforms
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from typing import Optional, Tuple, Union
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def morphological_open(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
|
| 10 |
+
"""
|
| 11 |
+
Perform morphological opening on a 2D torch tensor (image).
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
image (torch.Tensor): image to open
|
| 15 |
+
kernel_size (int): size of the structuring element - roughly the size of hole to be opened
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
torch.Tensor: The opened image.
|
| 19 |
+
"""
|
| 20 |
+
kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
|
| 21 |
+
eroded = F.conv2d(image.unsqueeze(0), kernel, stride=1, padding=kernel_size // 2)
|
| 22 |
+
eroded = (eroded > 0).float()
|
| 23 |
+
dilated = F.conv2d(eroded, kernel, stride=1, padding=kernel_size // 2)
|
| 24 |
+
return (dilated > 0).float()
|
| 25 |
+
|
| 26 |
+
def morphological_close(image: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
|
| 27 |
+
"""
|
| 28 |
+
Perform morphological closing on a 2D torch tensor (image).
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
image (torch.Tensor): image to close
|
| 32 |
+
kernel_size (int): size of the structuring element - roughly the size of hole to be closed
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
torch.Tensor: The closed image.
|
| 36 |
+
"""
|
| 37 |
+
kernel = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=image.device)
|
| 38 |
+
dilated = F.conv2d(image.unsqueeze(0), kernel, stride=1, padding=kernel_size // 2)
|
| 39 |
+
dilated = (dilated > 0).float()
|
| 40 |
+
eroded = F.conv2d(dilated, kernel, stride=1, padding=kernel_size // 2)
|
| 41 |
+
return (eroded > 0).float()
|
| 42 |
+
|
| 43 |
+
def gaussian_convolve(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
|
| 44 |
+
"""
|
| 45 |
+
Gaussian Convolution to smooth image
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
image (torch.Tensor): image to convolve
|
| 49 |
+
kernel_size (int): size of the Gaussian kernel
|
| 50 |
+
sigma (float): standard deviation of the Gaussian distribution
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
torch.Tensor: The convolved image.
|
| 54 |
+
"""
|
| 55 |
+
x = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
|
| 56 |
+
y = torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=torch.float32)
|
| 57 |
+
x, y = torch.meshgrid(x, y)
|
| 58 |
+
kernel = torch.exp(-(x**2 + y**2) / (2 * sigma**2))
|
| 59 |
+
kernel = kernel / kernel.sum()
|
| 60 |
+
# Apply the Gaussian kernel
|
| 61 |
+
return F.conv2d(image.unsqueeze(0), kernel.unsqueeze(0).unsqueeze(0), stride=1, padding=kernel_size // 2)
|
| 62 |
+
|
| 63 |
+
def hysteresis_filter(image: torch.Tensor, low_threshold: float, high_threshold: float) -> torch.Tensor:
|
| 64 |
+
"""
|
| 65 |
+
Hysteresis Filter Function - for Canny Edge detection
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
image (torch.Tensor): image to process
|
| 69 |
+
low_threshold (float): low threshold for hysteresis
|
| 70 |
+
high_threshold (float): high threshold for hysteresis
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
edge (torch.Tensor): The edges detected in the image.
|
| 74 |
+
|
| 75 |
+
"""
|
| 76 |
+
edges = (image > high_threshold).float()
|
| 77 |
+
# Perform hysteresis thresholding
|
| 78 |
+
edges = torch.where(image > low_threshold, edges, 0)
|
| 79 |
+
return edges
|
| 80 |
+
|
| 81 |
+
def non_maxima_suppression_2d(
|
| 82 |
+
image: torch.Tensor,
|
| 83 |
+
kernel_size: int = 3,
|
| 84 |
+
threshold: Optional[float] = None,
|
| 85 |
+
return_mask: bool = False
|
| 86 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
| 87 |
+
"""
|
| 88 |
+
Perform non-maxima suppression on a 2D torch tensor (image).
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
image (torch.Tensor): Input tensor of shape (H, W) or (B, C, H, W) or (C, H, W)
|
| 92 |
+
kernel_size (int): Size of the local neighborhood for maxima detection (default: 3)
|
| 93 |
+
threshold (float, optional): Minimum value threshold for considering pixels
|
| 94 |
+
return_mask (bool): If True, return both suppressed image and binary mask
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
torch.Tensor: Image with non-maxima suppressed
|
| 98 |
+
torch.Tensor (optional): Binary mask of local maxima if return_mask=True
|
| 99 |
+
"""
|
| 100 |
+
original_shape = image.shape
|
| 101 |
+
|
| 102 |
+
# Handle different input shapes
|
| 103 |
+
if len(image.shape) == 2: # (H, W)
|
| 104 |
+
image = image.unsqueeze(0).unsqueeze(0) # (1, 1, H, W)
|
| 105 |
+
elif len(image.shape) == 3: # (C, H, W)
|
| 106 |
+
image = image.unsqueeze(0) # (1, C, H, W)
|
| 107 |
+
elif len(image.shape) == 4: # (B, C, H, W)
|
| 108 |
+
pass
|
| 109 |
+
else:
|
| 110 |
+
raise ValueError(f"Unsupported tensor shape: {original_shape}")
|
| 111 |
+
|
| 112 |
+
batch_size, channels, height, width = image.shape
|
| 113 |
+
|
| 114 |
+
# Apply threshold if specified
|
| 115 |
+
if threshold is not None:
|
| 116 |
+
image = torch.where(image >= threshold, image, torch.tensor(0.0, device=image.device))
|
| 117 |
+
|
| 118 |
+
# Perform max pooling to find local maxima
|
| 119 |
+
padding = kernel_size // 2
|
| 120 |
+
max_pooled = F.max_pool2d(image, kernel_size=kernel_size, stride=1, padding=padding)
|
| 121 |
+
|
| 122 |
+
# Create mask where original values equal max pooled values (local maxima)
|
| 123 |
+
mask = (image == max_pooled) & (image > 0)
|
| 124 |
+
|
| 125 |
+
# Apply non-maxima suppression
|
| 126 |
+
suppressed = image * mask.float()
|
| 127 |
+
|
| 128 |
+
# Reshape back to original shape
|
| 129 |
+
if len(original_shape) == 2:
|
| 130 |
+
suppressed = suppressed.squeeze(0).squeeze(0)
|
| 131 |
+
mask = mask.squeeze(0).squeeze(0)
|
| 132 |
+
elif len(original_shape) == 3:
|
| 133 |
+
suppressed = suppressed.squeeze(0)
|
| 134 |
+
mask = mask.squeeze(0)
|
| 135 |
+
|
| 136 |
+
if return_mask:
|
| 137 |
+
return suppressed, mask
|
| 138 |
+
return suppressed
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def non_maxima_suppression_with_orientation(
|
| 142 |
+
magnitude: torch.Tensor,
|
| 143 |
+
orientation: torch.Tensor,
|
| 144 |
+
threshold: Optional[float] = None
|
| 145 |
+
) -> torch.Tensor:
|
| 146 |
+
"""
|
| 147 |
+
Perform oriented non-maxima suppression (commonly used in edge detection).
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
magnitude (torch.Tensor): Gradient magnitude tensor of shape (H, W) or (B, C, H, W)
|
| 151 |
+
orientation (torch.Tensor): Gradient orientation tensor (in radians) of same shape
|
| 152 |
+
threshold (float, optional): Minimum magnitude threshold
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
torch.Tensor: Non-maxima suppressed magnitude
|
| 156 |
+
"""
|
| 157 |
+
original_shape = magnitude.shape
|
| 158 |
+
|
| 159 |
+
# Handle different input shapes
|
| 160 |
+
if len(magnitude.shape) == 2:
|
| 161 |
+
magnitude = magnitude.unsqueeze(0).unsqueeze(0)
|
| 162 |
+
orientation = orientation.unsqueeze(0).unsqueeze(0)
|
| 163 |
+
elif len(magnitude.shape) == 3:
|
| 164 |
+
magnitude = magnitude.unsqueeze(0)
|
| 165 |
+
orientation = orientation.unsqueeze(0)
|
| 166 |
+
|
| 167 |
+
batch_size, channels, height, width = magnitude.shape
|
| 168 |
+
device = magnitude.device
|
| 169 |
+
|
| 170 |
+
# Apply threshold if specified
|
| 171 |
+
if threshold is not None:
|
| 172 |
+
magnitude = torch.where(magnitude >= threshold, magnitude, torch.tensor(0.0, device=device))
|
| 173 |
+
|
| 174 |
+
# Convert orientation to degrees and normalize to [0, 180)
|
| 175 |
+
angle = torch.rad2deg(orientation) % 180
|
| 176 |
+
|
| 177 |
+
# Create padded magnitude for neighbor comparison
|
| 178 |
+
mag_padded = F.pad(magnitude, (1, 1, 1, 1), mode='constant', value=0)
|
| 179 |
+
|
| 180 |
+
# Initialize output
|
| 181 |
+
suppressed = torch.zeros_like(magnitude)
|
| 182 |
+
|
| 183 |
+
# Define 8-connectivity neighbors
|
| 184 |
+
for b in range(batch_size):
|
| 185 |
+
for c in range(channels):
|
| 186 |
+
mag = magnitude[b, c]
|
| 187 |
+
ang = angle[b, c]
|
| 188 |
+
mag_pad = mag_padded[b, c]
|
| 189 |
+
|
| 190 |
+
for i in range(1, height + 1):
|
| 191 |
+
for j in range(1, width + 1):
|
| 192 |
+
current_mag = mag_pad[i, j]
|
| 193 |
+
current_angle = ang[i-1, j-1]
|
| 194 |
+
|
| 195 |
+
if current_mag == 0:
|
| 196 |
+
continue
|
| 197 |
+
|
| 198 |
+
# Determine interpolation direction based on angle
|
| 199 |
+
if (0 <= current_angle < 22.5) or (157.5 <= current_angle < 180):
|
| 200 |
+
# Horizontal direction (0°)
|
| 201 |
+
neighbor1 = mag_pad[i, j-1]
|
| 202 |
+
neighbor2 = mag_pad[i, j+1]
|
| 203 |
+
elif 22.5 <= current_angle < 67.5:
|
| 204 |
+
# Diagonal direction (45°)
|
| 205 |
+
neighbor1 = mag_pad[i-1, j+1]
|
| 206 |
+
neighbor2 = mag_pad[i+1, j-1]
|
| 207 |
+
elif 67.5 <= current_angle < 112.5:
|
| 208 |
+
# Vertical direction (90°)
|
| 209 |
+
neighbor1 = mag_pad[i-1, j]
|
| 210 |
+
neighbor2 = mag_pad[i+1, j]
|
| 211 |
+
else: # 112.5 <= current_angle < 157.5
|
| 212 |
+
# Diagonal direction (135°)
|
| 213 |
+
neighbor1 = mag_pad[i-1, j-1]
|
| 214 |
+
neighbor2 = mag_pad[i+1, j+1]
|
| 215 |
+
|
| 216 |
+
# Keep pixel if it's a local maximum
|
| 217 |
+
if current_mag >= neighbor1 and current_mag >= neighbor2:
|
| 218 |
+
suppressed[b, c, i-1, j-1] = current_mag
|
| 219 |
+
|
| 220 |
+
# Reshape back to original shape
|
| 221 |
+
if len(original_shape) == 2:
|
| 222 |
+
suppressed = suppressed.squeeze(0).squeeze(0)
|
| 223 |
+
elif len(original_shape) == 3:
|
| 224 |
+
suppressed = suppressed.squeeze(0)
|
| 225 |
+
|
| 226 |
+
return suppressed
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def adaptive_non_maxima_suppression(
|
| 230 |
+
image: torch.Tensor,
|
| 231 |
+
num_points: int,
|
| 232 |
+
min_distance: int = 5,
|
| 233 |
+
threshold: Optional[float] = None
|
| 234 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 235 |
+
"""
|
| 236 |
+
Adaptive non-maxima suppression that selects a fixed number of strongest points
|
| 237 |
+
while maintaining minimum distance between them.
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
image (torch.Tensor): Input tensor of shape (H, W)
|
| 241 |
+
num_points (int): Number of points to select
|
| 242 |
+
min_distance (int): Minimum distance between selected points
|
| 243 |
+
threshold (float, optional): Minimum value threshold
|
| 244 |
+
|
| 245 |
+
Returns:
|
| 246 |
+
Tuple[torch.Tensor, torch.Tensor]: Coordinates (y, x) and values of selected points
|
| 247 |
+
"""
|
| 248 |
+
if len(image.shape) != 2:
|
| 249 |
+
raise ValueError("Input must be a 2D tensor")
|
| 250 |
+
|
| 251 |
+
height, width = image.shape
|
| 252 |
+
device = image.device
|
| 253 |
+
|
| 254 |
+
# Apply threshold if specified
|
| 255 |
+
if threshold is not None:
|
| 256 |
+
image = torch.where(image >= threshold, image, torch.tensor(0.0, device=device))
|
| 257 |
+
|
| 258 |
+
# Find all local maxima using simple NMS
|
| 259 |
+
nms_result = non_maxima_suppression_2d(image, kernel_size=3)
|
| 260 |
+
|
| 261 |
+
# Get coordinates and values of all local maxima
|
| 262 |
+
y_coords, x_coords = torch.nonzero(nms_result > 0, as_tuple=True)
|
| 263 |
+
values = nms_result[y_coords, x_coords]
|
| 264 |
+
|
| 265 |
+
if len(values) == 0:
|
| 266 |
+
return torch.empty((0, 2), device=device), torch.empty(0, device=device)
|
| 267 |
+
|
| 268 |
+
# Sort by strength (descending)
|
| 269 |
+
sorted_indices = torch.argsort(values, descending=True)
|
| 270 |
+
y_coords = y_coords[sorted_indices]
|
| 271 |
+
x_coords = x_coords[sorted_indices]
|
| 272 |
+
values = values[sorted_indices]
|
| 273 |
+
|
| 274 |
+
# Select points with minimum distance constraint
|
| 275 |
+
selected_coords = []
|
| 276 |
+
selected_values = []
|
| 277 |
+
|
| 278 |
+
for i in range(len(values)):
|
| 279 |
+
if len(selected_coords) >= num_points:
|
| 280 |
+
break
|
| 281 |
+
|
| 282 |
+
current_y, current_x = y_coords[i].item(), x_coords[i].item()
|
| 283 |
+
current_val = values[i].item()
|
| 284 |
+
|
| 285 |
+
# Check distance to all previously selected points
|
| 286 |
+
valid = True
|
| 287 |
+
for sel_y, sel_x in selected_coords:
|
| 288 |
+
distance = ((current_y - sel_y) ** 2 + (current_x - sel_x) ** 2) ** 0.5
|
| 289 |
+
if distance < min_distance:
|
| 290 |
+
valid = False
|
| 291 |
+
break
|
| 292 |
+
|
| 293 |
+
if valid:
|
| 294 |
+
selected_coords.append((current_y, current_x))
|
| 295 |
+
selected_values.append(current_val)
|
| 296 |
+
|
| 297 |
+
if selected_coords:
|
| 298 |
+
coords_tensor = torch.tensor(selected_coords, device=device, dtype=torch.float32)
|
| 299 |
+
values_tensor = torch.tensor(selected_values, device=device, dtype=torch.float32)
|
| 300 |
+
else:
|
| 301 |
+
coords_tensor = torch.empty((0, 2), device=device)
|
| 302 |
+
values_tensor = torch.empty(0, device=device)
|
| 303 |
+
|
| 304 |
+
return coords_tensor, values_tensor
|
| 305 |
+
|
util/size_aware_batching.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Size-aware batching utilities for variable-sized seismic images
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import DataLoader, Sampler
|
| 7 |
+
import numpy as np
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
import random
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SizeAwareSampler(Sampler):
|
| 13 |
+
"""
|
| 14 |
+
Groups samples by size and creates batches with images of the same size
|
| 15 |
+
"""
|
| 16 |
+
def __init__(self, dataset, batch_size, get_size_fn=None):
|
| 17 |
+
"""
|
| 18 |
+
Args:
|
| 19 |
+
dataset: PyTorch dataset
|
| 20 |
+
batch_size: batch size for each size group
|
| 21 |
+
get_size_fn: function that takes dataset index and returns (height, width)
|
| 22 |
+
If None, will try to infer from dataset
|
| 23 |
+
"""
|
| 24 |
+
self.dataset = dataset
|
| 25 |
+
self.batch_size = batch_size
|
| 26 |
+
self.get_size_fn = get_size_fn
|
| 27 |
+
|
| 28 |
+
# Group indices by size
|
| 29 |
+
self.size_groups = self._group_by_size()
|
| 30 |
+
|
| 31 |
+
# Create batches
|
| 32 |
+
self.batches = self._create_batches()
|
| 33 |
+
|
| 34 |
+
def _group_by_size(self):
|
| 35 |
+
"""Group dataset indices by image size"""
|
| 36 |
+
size_groups = defaultdict(list)
|
| 37 |
+
|
| 38 |
+
for idx in range(len(self.dataset)):
|
| 39 |
+
if self.get_size_fn:
|
| 40 |
+
size = self.get_size_fn(idx)
|
| 41 |
+
else:
|
| 42 |
+
# Try to get size from dataset item
|
| 43 |
+
sample = self.dataset[idx]
|
| 44 |
+
if isinstance(sample, (tuple, list)):
|
| 45 |
+
# Assume first element is the image tensor
|
| 46 |
+
img_tensor = sample[0]
|
| 47 |
+
else:
|
| 48 |
+
img_tensor = sample
|
| 49 |
+
|
| 50 |
+
# Get size from tensor shape (assuming shape is [C, H, W] or [H, W])
|
| 51 |
+
if len(img_tensor.shape) == 3:
|
| 52 |
+
size = (img_tensor.shape[1], img_tensor.shape[2]) # H, W
|
| 53 |
+
elif len(img_tensor.shape) == 2:
|
| 54 |
+
size = (img_tensor.shape[0], img_tensor.shape[1]) # H, W
|
| 55 |
+
else:
|
| 56 |
+
raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
|
| 57 |
+
|
| 58 |
+
size_groups[size].append(idx)
|
| 59 |
+
|
| 60 |
+
return size_groups
|
| 61 |
+
def _create_batches(self, random_size = True):
|
| 62 |
+
"""Create batches from size groups"""
|
| 63 |
+
batches = []
|
| 64 |
+
|
| 65 |
+
for size, indices in self.size_groups.items():
|
| 66 |
+
# Shuffle indices within each size group
|
| 67 |
+
random.shuffle(indices)
|
| 68 |
+
|
| 69 |
+
# Create batches of the specified size
|
| 70 |
+
for i in range(0, len(indices), self.batch_size):
|
| 71 |
+
batch = indices[i:i + self.batch_size]
|
| 72 |
+
batches.append(batch)
|
| 73 |
+
|
| 74 |
+
return batches
|
| 75 |
+
|
| 76 |
+
def __iter__(self):
|
| 77 |
+
# Shuffle the order of batches
|
| 78 |
+
random.shuffle(self.batches)
|
| 79 |
+
for batch in self.batches:
|
| 80 |
+
yield batch
|
| 81 |
+
|
| 82 |
+
def __len__(self):
|
| 83 |
+
return len(self.batches)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class FixedSizeSampler(Sampler):
|
| 87 |
+
"""
|
| 88 |
+
Sampler for datasets where you know the exact 3 size categories
|
| 89 |
+
More efficient than SizeAwareSampler when sizes are known
|
| 90 |
+
"""
|
| 91 |
+
def __init__(self, dataset, batch_size, size_categories):
|
| 92 |
+
"""
|
| 93 |
+
Args:
|
| 94 |
+
dataset: PyTorch dataset
|
| 95 |
+
batch_size: batch size for each size category
|
| 96 |
+
size_categories: list of (height, width) tuples for the 3 categories
|
| 97 |
+
e.g., [(601, 200), (200, 255), (601, 255)]
|
| 98 |
+
"""
|
| 99 |
+
self.dataset = dataset
|
| 100 |
+
self.batch_size = batch_size
|
| 101 |
+
self.size_categories = size_categories
|
| 102 |
+
|
| 103 |
+
# Map indices to size categories
|
| 104 |
+
self.size_to_indices = {size: [] for size in size_categories}
|
| 105 |
+
self._categorize_indices()
|
| 106 |
+
|
| 107 |
+
# Create batches
|
| 108 |
+
self.batches = self._create_batches()
|
| 109 |
+
|
| 110 |
+
def _categorize_indices(self):
|
| 111 |
+
"""Categorize dataset indices by their size"""
|
| 112 |
+
for idx in range(len(self.dataset)):
|
| 113 |
+
sample = self.dataset[idx]
|
| 114 |
+
if isinstance(sample, (tuple, list)):
|
| 115 |
+
img_tensor = sample[0]
|
| 116 |
+
else:
|
| 117 |
+
img_tensor = sample
|
| 118 |
+
|
| 119 |
+
# Get size from tensor
|
| 120 |
+
if len(img_tensor.shape) == 3:
|
| 121 |
+
size = (img_tensor.shape[1], img_tensor.shape[2])
|
| 122 |
+
elif len(img_tensor.shape) == 2:
|
| 123 |
+
size = (img_tensor.shape[0], img_tensor.shape[1])
|
| 124 |
+
else:
|
| 125 |
+
raise ValueError(f"Unexpected tensor shape: {img_tensor.shape}")
|
| 126 |
+
|
| 127 |
+
# Find matching category
|
| 128 |
+
if size in self.size_categories:
|
| 129 |
+
self.size_to_indices[size].append(idx)
|
| 130 |
+
else:
|
| 131 |
+
# Find closest size category (optional)
|
| 132 |
+
closest_size = min(self.size_categories,
|
| 133 |
+
key=lambda cat: abs(cat[0] - size[0]) + abs(cat[1] - size[1]))
|
| 134 |
+
print(f"Warning: Size {size} not in categories, assigning to {closest_size}")
|
| 135 |
+
self.size_to_indices[closest_size].append(idx)
|
| 136 |
+
|
| 137 |
+
def _create_batches(self, random_size = True):
|
| 138 |
+
"""Create batches from size categories"""
|
| 139 |
+
batches = []
|
| 140 |
+
|
| 141 |
+
for size, indices in self.size_to_indices.items():
|
| 142 |
+
if not indices:
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
# Shuffle indices within each size category
|
| 146 |
+
random.shuffle(indices)
|
| 147 |
+
|
| 148 |
+
# Create batches
|
| 149 |
+
for i in range(0, len(indices), self.batch_size):
|
| 150 |
+
batch = indices[i:i + self.batch_size]
|
| 151 |
+
batches.append(batch)
|
| 152 |
+
|
| 153 |
+
return batches
|
| 154 |
+
|
| 155 |
+
def __iter__(self):
|
| 156 |
+
# Shuffle the order of batches across all size categories
|
| 157 |
+
random.shuffle(self.batches)
|
| 158 |
+
for batch in self.batches:
|
| 159 |
+
yield batch
|
| 160 |
+
|
| 161 |
+
def __len__(self):
|
| 162 |
+
return len(self.batches)
|
| 163 |
+
|
| 164 |
+
def get_size_distribution(self):
|
| 165 |
+
"""Get the distribution of samples across size categories"""
|
| 166 |
+
distribution = {}
|
| 167 |
+
for size, indices in self.size_to_indices.items():
|
| 168 |
+
distribution[size] = len(indices)
|
| 169 |
+
return distribution
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def create_size_aware_dataloader(dataset, batch_size=8, size_categories=None,
|
| 173 |
+
num_workers=4, pin_memory=True, **kwargs):
|
| 174 |
+
"""
|
| 175 |
+
Create a DataLoader that batches samples by size
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
dataset: PyTorch dataset
|
| 179 |
+
batch_size: batch size for each size group
|
| 180 |
+
size_categories: list of (height, width) tuples for known size categories
|
| 181 |
+
If None, will auto-detect sizes
|
| 182 |
+
num_workers: number of worker processes
|
| 183 |
+
pin_memory: whether to pin memory
|
| 184 |
+
**kwargs: additional arguments for DataLoader
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
DataLoader with size-aware batching
|
| 188 |
+
"""
|
| 189 |
+
if size_categories:
|
| 190 |
+
sampler = FixedSizeSampler(dataset, batch_size, size_categories)
|
| 191 |
+
else:
|
| 192 |
+
sampler = SizeAwareSampler(dataset, batch_size)
|
| 193 |
+
|
| 194 |
+
# Remove batch_size from kwargs since we're using a custom sampler
|
| 195 |
+
kwargs.pop('batch_size', None)
|
| 196 |
+
kwargs.pop('shuffle', None) # Sampler handles shuffling
|
| 197 |
+
|
| 198 |
+
return DataLoader(
|
| 199 |
+
dataset,
|
| 200 |
+
batch_sampler=sampler,
|
| 201 |
+
num_workers=num_workers,
|
| 202 |
+
pin_memory=pin_memory,
|
| 203 |
+
**kwargs
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# Custom collate function for same-size batches (no padding needed)
|
| 208 |
+
def same_size_collate_fn(batch):
|
| 209 |
+
"""
|
| 210 |
+
Collate function for batches where all items have the same size
|
| 211 |
+
No padding required since all images in batch are same size
|
| 212 |
+
"""
|
| 213 |
+
if isinstance(batch[0], (tuple, list)):
|
| 214 |
+
# Assuming (image, target) pairs
|
| 215 |
+
images, targets = zip(*batch)
|
| 216 |
+
return torch.stack(images), torch.stack(targets)
|
| 217 |
+
else:
|
| 218 |
+
# Just images
|
| 219 |
+
return torch.stack(batch)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# Utility function to check batch sizes
|
| 224 |
+
def validate_batch_sizes(dataloader, num_batches_to_check=5):
|
| 225 |
+
"""
|
| 226 |
+
Validate that all images in each batch have the same size
|
| 227 |
+
"""
|
| 228 |
+
print("Validating batch sizes...")
|
| 229 |
+
|
| 230 |
+
for i, batch in enumerate(dataloader):
|
| 231 |
+
if i >= num_batches_to_check:
|
| 232 |
+
break
|
| 233 |
+
|
| 234 |
+
if isinstance(batch, (tuple, list)):
|
| 235 |
+
images = batch[0]
|
| 236 |
+
else:
|
| 237 |
+
images = batch
|
| 238 |
+
|
| 239 |
+
batch_size = images.shape[0]
|
| 240 |
+
height = images.shape[2]
|
| 241 |
+
width = images.shape[3]
|
| 242 |
+
|
| 243 |
+
print(f"Batch {i}: {batch_size} images of size {height}x{width}")
|
| 244 |
+
|
| 245 |
+
# Verify all images in batch have same size
|
| 246 |
+
for j in range(batch_size):
|
| 247 |
+
img_h, img_w = images[j].shape[1], images[j].shape[2]
|
| 248 |
+
if img_h != height or img_w != width:
|
| 249 |
+
print(f" WARNING: Image {j} has different size {img_h}x{img_w}")
|
| 250 |
+
|
| 251 |
+
print("Validation complete!")
|
util/skeletonize.py
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Courtesy of Martin Mentan:
|
| 3 |
+
|
| 4 |
+
Works Cited
|
| 5 |
+
Menten, Martin J., et al. ‘A Skeletonization Algorithm for Gradient-Based Optimization’.
|
| 6 |
+
Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Skeletonize(torch.nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
Class based on PyTorch's Module class to skeletonize two- or three-dimensional input images
|
| 18 |
+
while being fully compatible with PyTorch's autograd automatic differention engine as proposed in [1].
|
| 19 |
+
|
| 20 |
+
Attributes:
|
| 21 |
+
propabilistic: a Boolean that indicates whether the input image should be binarized using
|
| 22 |
+
the reparametrization trick and straight-through estimator.
|
| 23 |
+
It should always be set to True if non-binary inputs are being provided.
|
| 24 |
+
beta: scale of added logistic noise during the reparametrization trick. If too small, there will not be any learning via
|
| 25 |
+
gradient-based optimization; if too large, the learning is very slow.
|
| 26 |
+
tau: Boltzmann temperature for reparametrization trick.
|
| 27 |
+
simple_point_detection: decides whether simple points should be identified using Boolean characterization of their 26-neighborhood (Boolean) [2]
|
| 28 |
+
or by checking whether the Euler characteristic changes under their deletion (EulerCharacteristic) [3].
|
| 29 |
+
num_iter: number of iterations that each include one end-point check, eight checks for simple points and eight subsequent deletions.
|
| 30 |
+
The number of iterations should be tuned to the type of input image.
|
| 31 |
+
|
| 32 |
+
[1] Martin J. Menten et al. A skeletonization algorithm for gradient-based optimization.
|
| 33 |
+
Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2023.
|
| 34 |
+
[2] Gilles Bertrand. A boolean characterization of three- dimensional simple points.
|
| 35 |
+
Pattern recognition letters, 17(2):115-124, 1996.
|
| 36 |
+
[3] Steven Lobregt et al. Three-dimensional skeletonization:principle and algorithm.
|
| 37 |
+
IEEE Transactions on pattern analysis and machine intelligence, 2(1):75-77, 1980.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, probabilistic=True, beta=0.33, tau=1.0, simple_point_detection='Boolean', num_iter=5):
|
| 41 |
+
|
| 42 |
+
super(Skeletonize, self).__init__()
|
| 43 |
+
|
| 44 |
+
self.probabilistic = probabilistic
|
| 45 |
+
self.tau = tau
|
| 46 |
+
self.beta = beta
|
| 47 |
+
|
| 48 |
+
self.num_iter = num_iter
|
| 49 |
+
self.endpoint_check = self._single_neighbor_check
|
| 50 |
+
if simple_point_detection == 'Boolean':
|
| 51 |
+
self.simple_check = self._boolean_simple_check
|
| 52 |
+
elif simple_point_detection == 'EulerCharacteristic':
|
| 53 |
+
self.simple_check = self._euler_characteristic_simple_check
|
| 54 |
+
else:
|
| 55 |
+
raise Exception()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def forward(self, img):
|
| 59 |
+
|
| 60 |
+
img = self._prepare_input(img)
|
| 61 |
+
|
| 62 |
+
if self.probabilistic:
|
| 63 |
+
img = self._stochastic_discretization(img)
|
| 64 |
+
|
| 65 |
+
for current_iter in range(self.num_iter):
|
| 66 |
+
|
| 67 |
+
# At each iteration create a new map of the end-points
|
| 68 |
+
is_endpoint = self.endpoint_check(img)
|
| 69 |
+
|
| 70 |
+
# Sub-iterate through eight different subfields
|
| 71 |
+
x_offsets = [0, 1, 0, 1, 0, 1, 0, 1]
|
| 72 |
+
y_offsets = [0, 0, 1, 1, 0, 0, 1, 1]
|
| 73 |
+
z_offsets = [0, 0, 0, 0, 1, 1, 1, 1]
|
| 74 |
+
|
| 75 |
+
for x_offset, y_offset, z_offset in zip(x_offsets, y_offsets, z_offsets):
|
| 76 |
+
|
| 77 |
+
# At each sub-iteration detect all simple points and delete all simple points that are not end-points
|
| 78 |
+
is_simple = self.simple_check(img[:, :, x_offset:, y_offset:, z_offset:])
|
| 79 |
+
deletion_candidates = is_simple * (1 - is_endpoint[:, :, x_offset::2, y_offset::2, z_offset::2])
|
| 80 |
+
img[:, :, x_offset::2, y_offset::2, z_offset::2] = torch.min(img[:, :, x_offset::2, y_offset::2, z_offset::2].clone(), 1 - deletion_candidates)
|
| 81 |
+
|
| 82 |
+
img = self._prepare_output(img)
|
| 83 |
+
|
| 84 |
+
return img
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _prepare_input(self, img):
|
| 89 |
+
"""
|
| 90 |
+
Function to check that the input image is compatible with the subsequent calculations.
|
| 91 |
+
Only two- and three-dimensional images with values between 0 and 1 are supported.
|
| 92 |
+
If the input image is two-dimensional then it is converted into a three-dimensional one for further processing.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
if img.dim() == 5:
|
| 96 |
+
self.expanded_dims = False
|
| 97 |
+
elif img.dim() == 4:
|
| 98 |
+
self.expanded_dims = True
|
| 99 |
+
img = img.unsqueeze(2)
|
| 100 |
+
else:
|
| 101 |
+
raise Exception("Only two-or three-dimensional images (tensor dimensionality of 4 or 5) are supported as input.")
|
| 102 |
+
|
| 103 |
+
if img.shape[2] == 2 or img.shape[3] == 2 or img.shape[4] == 2 or img.shape[3] == 1 or img.shape[4] == 1:
|
| 104 |
+
raise Exception()
|
| 105 |
+
|
| 106 |
+
if img.min() < 0.0 or img.max() > 1.0:
|
| 107 |
+
raise Exception("Image values must lie between 0 and 1.")
|
| 108 |
+
|
| 109 |
+
img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
|
| 110 |
+
|
| 111 |
+
return img
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def _stochastic_discretization(self, img):
|
| 115 |
+
"""
|
| 116 |
+
Function to binarize the image so that it can be processed by our skeletonization method.
|
| 117 |
+
In order to remain compatible with backpropagation we utilize the reparameterization trick and a straight-through estimator.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
alpha = (img + 1e-8) / (1.0 - img + 1e-8)
|
| 121 |
+
|
| 122 |
+
uniform_noise = torch.rand_like(img)
|
| 123 |
+
uniform_noise = torch.empty_like(img).uniform_(1e-8, 1 - 1e-8)
|
| 124 |
+
logistic_noise = (torch.log(uniform_noise) - torch.log(1 - uniform_noise))
|
| 125 |
+
|
| 126 |
+
img = torch.sigmoid((torch.log(alpha) + logistic_noise * self.beta) / self.tau)
|
| 127 |
+
img = (img.detach() > 0.5).float() - img.detach() + img
|
| 128 |
+
|
| 129 |
+
return img
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _single_neighbor_check(self, img):
|
| 133 |
+
"""
|
| 134 |
+
Function that characterizes points as endpoints if they have a single neighbor or no neighbor at all.
|
| 135 |
+
"""
|
| 136 |
+
|
| 137 |
+
img = F.pad(img, (1, 1, 1, 1, 1, 1))
|
| 138 |
+
|
| 139 |
+
# Check that number of ones in twentysix-neighborhood is exactly 0 or 1
|
| 140 |
+
K = torch.tensor([[[1.0, 1.0, 1.0],
|
| 141 |
+
[1.0, 1.0, 1.0],
|
| 142 |
+
[1.0, 1.0, 1.0]],
|
| 143 |
+
[[1.0, 1.0, 1.0],
|
| 144 |
+
[1.0, 0.0, 1.0],
|
| 145 |
+
[1.0, 1.0, 1.0]],
|
| 146 |
+
[[1.0, 1.0, 1.0],
|
| 147 |
+
[1.0, 1.0, 1.0],
|
| 148 |
+
[1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 149 |
+
|
| 150 |
+
num_twentysix_neighbors = F.conv3d(img, K)
|
| 151 |
+
condition1 = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neigbors
|
| 152 |
+
|
| 153 |
+
return condition1
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _boolean_simple_check(self, img):
|
| 157 |
+
"""
|
| 158 |
+
Function that identifies simple points using Boolean conditions introduced by Bertrand et al. [1].
|
| 159 |
+
Each Boolean conditions can be assessed via convolutions with a limited number of pre-defined kernels.
|
| 160 |
+
It total, four conditions are checked. If any one is fulfilled, the point is deemed simple.
|
| 161 |
+
|
| 162 |
+
[1] Gilles Bertrand. A boolean characterization of three- dimensional simple points.
|
| 163 |
+
Pattern recognition letters, 17(2):115-124, 1996.
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
|
| 167 |
+
|
| 168 |
+
# Condition 1: number of zeros in the six-neighborhood is exactly 1
|
| 169 |
+
K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
|
| 170 |
+
[0.0, 1.0, 0.0],
|
| 171 |
+
[0.0, 0.0, 0.0]],
|
| 172 |
+
[[0.0, 1.0, 0.0],
|
| 173 |
+
[1.0, 0.0, 1.0],
|
| 174 |
+
[0.0, 1.0, 0.0]],
|
| 175 |
+
[[0.0, 0.0, 0.0],
|
| 176 |
+
[0.0, 1.0, 0.0],
|
| 177 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 178 |
+
|
| 179 |
+
num_six_neighbors = F.conv3d(1 - img, K_N6, stride=2)
|
| 180 |
+
|
| 181 |
+
subcondition1a = F.hardtanh(num_six_neighbors, min_val=0, max_val=1) # 1 or more neighbors
|
| 182 |
+
subcondition1b = F.hardtanh(-(num_six_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neighbors
|
| 183 |
+
|
| 184 |
+
condition1 = subcondition1a * subcondition1b
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Condition 2: number of ones in twentysix-neighborhood is exactly 1
|
| 188 |
+
K_N26 = torch.tensor([[[1.0, 1.0, 1.0],
|
| 189 |
+
[1.0, 1.0, 1.0],
|
| 190 |
+
[1.0, 1.0, 1.0]],
|
| 191 |
+
[[1.0, 1.0, 1.0],
|
| 192 |
+
[1.0, 0.0, 1.0],
|
| 193 |
+
[1.0, 1.0, 1.0]],
|
| 194 |
+
[[1.0, 1.0, 1.0],
|
| 195 |
+
[1.0, 1.0, 1.0],
|
| 196 |
+
[1.0, 1.0, 1.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 197 |
+
|
| 198 |
+
num_twentysix_neighbors = F.conv3d(img, K_N26, stride=2)
|
| 199 |
+
|
| 200 |
+
subcondition2a = F.hardtanh(num_twentysix_neighbors, min_val=0, max_val=1) # 1 or more neighbors
|
| 201 |
+
subcondition2b = F.hardtanh(-(num_twentysix_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neigbors
|
| 202 |
+
|
| 203 |
+
condition2 = subcondition2a * subcondition2b
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
# Condition 3: Number of ones in eighteen-neigborhood exactly 1...
|
| 207 |
+
K_N18 = torch.tensor([[[0.0, 1.0, 0.0],
|
| 208 |
+
[1.0, 1.0, 1.0],
|
| 209 |
+
[0.0, 1.0, 0.0]],
|
| 210 |
+
[[1.0, 1.0, 1.0],
|
| 211 |
+
[1.0, 0.0, 1.0],
|
| 212 |
+
[1.0, 1.0, 1.0]],
|
| 213 |
+
[[0.0, 1.0, 0.0],
|
| 214 |
+
[1.0, 1.0, 1.0],
|
| 215 |
+
[0.0, 1.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 216 |
+
|
| 217 |
+
num_eighteen_neighbors = F.conv3d(img, K_N18, stride=2)
|
| 218 |
+
|
| 219 |
+
subcondition3a = F.hardtanh(num_eighteen_neighbors, min_val=0, max_val=1) # 1 or more neighbors
|
| 220 |
+
subcondition3b = F.hardtanh(-(num_eighteen_neighbors - 2), min_val=0, max_val=1) # 1 or fewer neigbors
|
| 221 |
+
|
| 222 |
+
# ... and cell configration B26 does not exist
|
| 223 |
+
K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
|
| 224 |
+
[-1.0, -1.0, 0.0],
|
| 225 |
+
[0.0, 0.0, 0.0]],
|
| 226 |
+
[[-1.0, -1.0, 0.0],
|
| 227 |
+
[-1.0, 0.0, 0.0],
|
| 228 |
+
[0.0, 0.0, 0.0]],
|
| 229 |
+
[[0.0, 0.0, 0.0],
|
| 230 |
+
[0.0, 0.0, 0.0],
|
| 231 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 232 |
+
|
| 233 |
+
B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
|
| 234 |
+
B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
|
| 235 |
+
B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
|
| 236 |
+
B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
|
| 237 |
+
B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
|
| 238 |
+
B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
|
| 239 |
+
B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
|
| 240 |
+
B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
|
| 241 |
+
num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present
|
| 242 |
+
|
| 243 |
+
subcondition3c = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)
|
| 244 |
+
|
| 245 |
+
condition3 = subcondition3a * subcondition3b * subcondition3c
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# Condition 4: cell configuration A6 does not exist...
|
| 249 |
+
K_A6 = torch.tensor([[[0.0, 1.0, 0.0],
|
| 250 |
+
[1.0, -1.0, 1.0],
|
| 251 |
+
[0.0, 1.0, 0.0]],
|
| 252 |
+
[[0.0, 0.0, 0.0],
|
| 253 |
+
[0.0, 0.0, 0.0],
|
| 254 |
+
[0.0, 0.0, 0.0]],
|
| 255 |
+
[[0.0, 0.0, 0.0],
|
| 256 |
+
[0.0, 0.0, 0.0],
|
| 257 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 258 |
+
|
| 259 |
+
A6_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A6, stride=2) - 4)
|
| 260 |
+
A6_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 3]), stride=2) - 4)
|
| 261 |
+
A6_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A6, dims=[2, 4]), stride=2) - 4)
|
| 262 |
+
A6_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A6, dims=[2]), stride=2) - 4)
|
| 263 |
+
A6_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 3]), stride=2) - 4)
|
| 264 |
+
A6_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.flip(K_A6, dims=[2]), dims=[2, 4]), stride=2) - 4)
|
| 265 |
+
num_A6_cells = A6_1_present + A6_2_present + A6_3_present + A6_4_present + A6_5_present + A6_6_present
|
| 266 |
+
|
| 267 |
+
subcondition4a = F.hardtanh(-(num_A6_cells - 1), min_val=0, max_val=1)
|
| 268 |
+
|
| 269 |
+
# ... and cell configuration B26 does not exist...
|
| 270 |
+
K_B26 = torch.tensor([[[1.0, -1.0, 0.0],
|
| 271 |
+
[-1.0, -1.0, 0.0],
|
| 272 |
+
[0.0, 0.0, 0.0]],
|
| 273 |
+
[[-1.0, -1.0, 0.0],
|
| 274 |
+
[-1.0, 0.0, 0.0],
|
| 275 |
+
[0.0, 0.0, 0.0]],
|
| 276 |
+
[[0.0, 0.0, 0.0],
|
| 277 |
+
[0.0, 0.0, 0.0],
|
| 278 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 279 |
+
|
| 280 |
+
B26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B26, stride=2) - 6)
|
| 281 |
+
B26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2]), stride=2) - 6)
|
| 282 |
+
B26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3]), stride=2) - 6)
|
| 283 |
+
B26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[4]), stride=2) - 6)
|
| 284 |
+
B26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3]), stride=2) - 6)
|
| 285 |
+
B26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 4]), stride=2) - 6)
|
| 286 |
+
B26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[3, 4]), stride=2) - 6)
|
| 287 |
+
B26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_B26, dims=[2, 3, 4]), stride=2) - 6)
|
| 288 |
+
num_B26_cells = B26_1_present + B26_2_present + B26_3_present + B26_4_present + B26_5_present + B26_6_present + B26_7_present + B26_8_present
|
| 289 |
+
|
| 290 |
+
subcondition4b = F.hardtanh(-(num_B26_cells - 1), min_val=0, max_val=1)
|
| 291 |
+
|
| 292 |
+
# ... and cell configuration B18 does not exist...
|
| 293 |
+
K_B18 = torch.tensor([[[0.0, 1.0, 0.0],
|
| 294 |
+
[-1.0, -1.0, -1.0],
|
| 295 |
+
[0.0, 0.0, 0.0]],
|
| 296 |
+
[[-1.0, -1.0, -1.0],
|
| 297 |
+
[-1.0, 0.0, -1.0],
|
| 298 |
+
[0.0, 0.0, 0.0]],
|
| 299 |
+
[[0.0, 0.0, 0.0],
|
| 300 |
+
[0.0, 0.0, 0.0],
|
| 301 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 302 |
+
|
| 303 |
+
B18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_B18, stride=2) - 8)
|
| 304 |
+
B18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4]), stride=2) - 8)
|
| 305 |
+
B18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=2), stride=2) - 8)
|
| 306 |
+
B18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[2, 4], k=3), stride=2) - 8)
|
| 307 |
+
B18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4]), stride=2) - 8)
|
| 308 |
+
B18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4]), stride=2) - 8)
|
| 309 |
+
B18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 8)
|
| 310 |
+
B18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 8)
|
| 311 |
+
B18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_B18, dims=[3, 4], k=2), stride=2) - 8)
|
| 312 |
+
B18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 8)
|
| 313 |
+
B18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 8)
|
| 314 |
+
B18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_B18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 8)
|
| 315 |
+
num_B18_cells = B18_1_present + B18_2_present + B18_3_present + B18_4_present + B18_5_present + B18_6_present + B18_7_present + B18_8_present + B18_9_present + B18_10_present + B18_11_present + B18_12_present
|
| 316 |
+
|
| 317 |
+
subcondition4c = F.hardtanh(-(num_B18_cells - 1), min_val=0, max_val=1)
|
| 318 |
+
|
| 319 |
+
# ... and the number of zeros in the six-neighborhood minus the number of A18 cell configurations plus the number of A26 cell configurations is exactly one
|
| 320 |
+
K_N6 = torch.tensor([[[0.0, 0.0, 0.0],
|
| 321 |
+
[0.0, 1.0, 0.0],
|
| 322 |
+
[0.0, 0.0, 0.0]],
|
| 323 |
+
[[0.0, 1.0, 0.0],
|
| 324 |
+
[1.0, 0.0, 1.0],
|
| 325 |
+
[0.0, 1.0, 0.0]],
|
| 326 |
+
[[0.0, 0.0, 0.0],
|
| 327 |
+
[0.0, 1.0, 0.0],
|
| 328 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 329 |
+
|
| 330 |
+
num_six_neighbors = F.conv3d(1-img, K_N6, stride=2)
|
| 331 |
+
|
| 332 |
+
K_A18 = torch.tensor([[[0.0, -1.0, 0.0],
|
| 333 |
+
[0.0, -1.0, 0.0],
|
| 334 |
+
[0.0, 0.0, 0.0]],
|
| 335 |
+
[[0.0, -1.0, 0.0],
|
| 336 |
+
[0.0, 0.0, 0.0],
|
| 337 |
+
[0.0, 0.0, 0.0]],
|
| 338 |
+
[[0.0, 0.0, 0.0],
|
| 339 |
+
[0.0, 0.0, 0.0],
|
| 340 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 341 |
+
|
| 342 |
+
A18_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A18, stride=2) - 2)
|
| 343 |
+
A18_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4]), stride=2) - 2)
|
| 344 |
+
A18_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=2), stride=2) - 2)
|
| 345 |
+
A18_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[2, 4], k=3), stride=2) - 2)
|
| 346 |
+
A18_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4]), stride=2) - 2)
|
| 347 |
+
A18_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4]), stride=2) - 2)
|
| 348 |
+
A18_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=2), stride=2) - 2)
|
| 349 |
+
A18_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4]), dims=[2, 4], k=3), stride=2) - 2)
|
| 350 |
+
A18_9_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(K_A18, dims=[3, 4], k=2), stride=2) - 2)
|
| 351 |
+
A18_10_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4]), stride=2) - 2)
|
| 352 |
+
A18_11_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=2), stride=2) - 2)
|
| 353 |
+
A18_12_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.rot90(torch.rot90(K_A18, dims=[3, 4], k=2), dims=[2, 4], k=3), stride=2) - 2)
|
| 354 |
+
num_A18_cells = A18_1_present + A18_2_present + A18_3_present + A18_4_present + A18_5_present + A18_6_present + A18_7_present + A18_8_present + A18_9_present + A18_10_present + A18_11_present + A18_12_present
|
| 355 |
+
|
| 356 |
+
K_A26 = torch.tensor([[[-1.0, -1.0, 0.0],
|
| 357 |
+
[-1.0, -1.0, 0.0],
|
| 358 |
+
[0.0, 0.0, 0.0]],
|
| 359 |
+
[[-1.0, -1.0, 0.0],
|
| 360 |
+
[-1.0, 0.0, 0.0],
|
| 361 |
+
[0.0, 0.0, 0.0]],
|
| 362 |
+
[[0.0, 0.0, 0.0],
|
| 363 |
+
[0.0, 0.0, 0.0],
|
| 364 |
+
[0.0, 0.0, 0.0]]], device=img.device).view(1, 1, 3, 3, 3)
|
| 365 |
+
|
| 366 |
+
A26_1_present = F.relu(F.conv3d(2.0 * img - 1.0, K_A26, stride=2) - 6)
|
| 367 |
+
A26_2_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2]), stride=2) - 6)
|
| 368 |
+
A26_3_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3]), stride=2) - 6)
|
| 369 |
+
A26_4_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[4]), stride=2) - 6)
|
| 370 |
+
A26_5_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3]), stride=2) - 6)
|
| 371 |
+
A26_6_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 4]), stride=2) - 6)
|
| 372 |
+
A26_7_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[3, 4]), stride=2) - 6)
|
| 373 |
+
A26_8_present = F.relu(F.conv3d(2.0 * img - 1.0, torch.flip(K_A26, dims=[2, 3, 4]), stride=2) - 6)
|
| 374 |
+
num_A26_cells = A26_1_present + A26_2_present + A26_3_present + A26_4_present + A26_5_present + A26_6_present + A26_7_present + A26_8_present
|
| 375 |
+
|
| 376 |
+
subcondition4d = F.hardtanh(num_six_neighbors - num_A18_cells + num_A26_cells, min_val=0, max_val=1) # 1 or more configurations
|
| 377 |
+
subcondition4e = F.hardtanh(-(num_six_neighbors - num_A18_cells + num_A26_cells - 2), min_val=0, max_val=1) # 1 or fewer configurations
|
| 378 |
+
|
| 379 |
+
condition4 = subcondition4a * subcondition4b * subcondition4c * subcondition4d * subcondition4e
|
| 380 |
+
|
| 381 |
+
# If any of the four conditions is fulfilled the point is simple
|
| 382 |
+
combined = torch.cat([condition1, condition2, condition3, condition4], dim=1)
|
| 383 |
+
is_simple = torch.amax(combined, dim=1, keepdim=True)
|
| 384 |
+
|
| 385 |
+
return is_simple
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# Specifically designed to be used with the eight-subfield iterative scheme from above.
|
| 389 |
+
def _euler_characteristic_simple_check(self, img):
|
| 390 |
+
"""
|
| 391 |
+
Function that identifies simple points by assessing whether the Euler characteristic changes when deleting it [1].
|
| 392 |
+
In order to calculate the Euler characteristic, the amount of vertices, edges, faces and octants are counted using convolutions with pre-defined kernels.
|
| 393 |
+
The function is meant to be used in combination with the subfield-based iterative scheme employed in the forward function.
|
| 394 |
+
|
| 395 |
+
[1] Steven Lobregt et al. Three-dimensional skeletonization:principle and algorithm.
|
| 396 |
+
IEEE Transactions on pattern analysis and machine intelligence, 2(1):75-77, 1980.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
img = F.pad(img, (1, 1, 1, 1, 1, 1), value=0)
|
| 400 |
+
|
| 401 |
+
# Create masked version of the image where the center of 26-neighborhoods is changed to zero
|
| 402 |
+
mask = torch.ones_like(img)
|
| 403 |
+
mask[:, :, 1::2, 1::2, 1::2] = 0
|
| 404 |
+
masked_img = img.clone() * mask
|
| 405 |
+
|
| 406 |
+
# Count vertices
|
| 407 |
+
vertices = F.relu(-(2.0 * img - 1.0))
|
| 408 |
+
num_vertices = F.avg_pool3d(vertices, (3, 3, 3), stride=2) * 27
|
| 409 |
+
|
| 410 |
+
masked_vertices = F.relu(-(2.0 * masked_img - 1.0))
|
| 411 |
+
num_masked_vertices = F.avg_pool3d(masked_vertices, (3, 3, 3), stride=2) * 27
|
| 412 |
+
|
| 413 |
+
# Count edges
|
| 414 |
+
K_ud_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 2, 1, 1)
|
| 415 |
+
K_ns_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 2, 1)
|
| 416 |
+
K_we_edge = torch.tensor([0.5, 0.5], device=img.device).view(1, 1, 1, 1, 2)
|
| 417 |
+
|
| 418 |
+
ud_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_edge))
|
| 419 |
+
num_ud_edges = F.avg_pool3d(ud_edges, (2, 3, 3), stride=2) * 18
|
| 420 |
+
ns_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_edge))
|
| 421 |
+
num_ns_edges = F.avg_pool3d(ns_edges, (3, 2, 3), stride=2) * 18
|
| 422 |
+
        we_edges = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_edge))
        num_we_edges = F.avg_pool3d(we_edges, (3, 3, 2), stride=2) * 18
        num_edges = num_ud_edges + num_ns_edges + num_we_edges

        masked_ud_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_edge))
        num_masked_ud_edges = F.avg_pool3d(masked_ud_edges, (2, 3, 3), stride=2) * 18
        masked_ns_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_edge))
        num_masked_ns_edges = F.avg_pool3d(masked_ns_edges, (3, 2, 3), stride=2) * 18
        masked_we_edges = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_edge))
        num_masked_we_edges = F.avg_pool3d(masked_we_edges, (3, 3, 2), stride=2) * 18
        num_masked_edges = num_masked_ud_edges + num_masked_ns_edges + num_masked_we_edges

        # Count faces
        K_ud_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 1, 2, 2)
        K_ns_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 1, 2)
        K_we_face = torch.tensor([[0.25, 0.25], [0.25, 0.25]], device=img.device).view(1, 1, 2, 2, 1)

        ud_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ud_face) - 0.5) * 2
        num_ud_faces = F.avg_pool3d(ud_faces, (3, 2, 2), stride=2) * 12
        ns_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_ns_face) - 0.5) * 2
        num_ns_faces = F.avg_pool3d(ns_faces, (2, 3, 2), stride=2) * 12
        we_faces = F.relu(F.conv3d(-(2.0 * img - 1.0), K_we_face) - 0.5) * 2
        num_we_faces = F.avg_pool3d(we_faces, (2, 2, 3), stride=2) * 12
        num_faces = num_ud_faces + num_ns_faces + num_we_faces

        masked_ud_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ud_face) - 0.5) * 2
        num_masked_ud_faces = F.avg_pool3d(masked_ud_faces, (3, 2, 2), stride=2) * 12
        masked_ns_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_ns_face) - 0.5) * 2
        num_masked_ns_faces = F.avg_pool3d(masked_ns_faces, (2, 3, 2), stride=2) * 12
        masked_we_faces = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_we_face) - 0.5) * 2
        num_masked_we_faces = F.avg_pool3d(masked_we_faces, (2, 2, 3), stride=2) * 12
        num_masked_faces = num_masked_ud_faces + num_masked_ns_faces + num_masked_we_faces

        # Count octants
        K_octants = torch.tensor([[[0.125, 0.125], [0.125, 0.125]], [[0.125, 0.125], [0.125, 0.125]]], device=img.device).view(1, 1, 2, 2, 2)

        octants = F.relu(F.conv3d(-(2.0 * img - 1.0), K_octants) - 0.75) * 4
        num_octants = F.avg_pool3d(octants, (2, 2, 2), stride=2) * 8

        masked_octants = F.relu(F.conv3d(-(2.0 * masked_img - 1.0), K_octants) - 0.75) * 4
        num_masked_octants = F.avg_pool3d(masked_octants, (2, 2, 2), stride=2) * 8

        # Combine the counts of vertices, edges, faces and octants into the Euler characteristic
        euler_characteristic = num_vertices - num_edges + num_faces - num_octants
        masked_euler_characteristic = num_masked_vertices - num_masked_edges + num_masked_faces - num_masked_octants

        # If the Euler characteristic is unchanged after switching a point from 1 to 0, the point is simple
        euler_change = F.hardtanh(torch.abs(masked_euler_characteristic - euler_characteristic), min_val=0, max_val=1)
        is_simple = 1 - euler_change
        # Straight-through estimator: hard-threshold in the forward pass, keep the soft gradient in the backward pass
        is_simple = (is_simple.detach() > 0.5).float() - is_simple.detach() + is_simple

        return is_simple

    def _prepare_output(self, img):
        """
        Remove the padding and the extra dimensions added by the _prepare_input function.
        """
        img = img[:, :, 1:-1, 1:-1, 1:-1]

        if self.expanded_dims:
            img = torch.squeeze(img, dim=2)

        return img
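The final reassignment of is_simple above is a straight-through estimator: the forward pass snaps the soft simple-point score to a hard 0/1 value, while the backward pass keeps the gradient of the soft value. A minimal, self-contained sketch of the same pattern (the tensor x and the sigmoid scoring are illustrative, not code from the repository):

import torch

x = torch.tensor([-1.0, 0.3, 2.0], requires_grad=True)
soft = torch.sigmoid(x)               # differentiable soft score in (0, 1)
hard = (soft.detach() > 0.5).float()  # non-differentiable hard decision
out = hard - soft.detach() + soft     # value == hard, gradient == d(soft)/dx

out.sum().backward()
print(out)     # tensor([0., 1., 1.], grad_fn=...)
print(x.grad)  # sigmoid'(x): the thresholding is invisible to the gradient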
util/tools.py ADDED
@@ -0,0 +1,143 @@
'''
Author: Jintao Li
Date: 2022-05-30 16:42:14
LastEditors: Jintao Li
LastEditTime: 2022-07-11 23:05:53
2022 by CIG.
'''

import os, shutil
import yaml, argparse
from sklearn.metrics import confusion_matrix
import numpy as np
import torch


def accuracy(output, target):
    '''
    output: [N, num_classes, ...], torch.float
    target: [N, ...], torch.int
    '''
    output = output.argmax(dim=1).flatten().detach().cpu().numpy()
    target = target.flatten().detach().cpu().numpy()
    return pixel_acc(output, target), _miou(output, target)


def pixel_acc(output, target):
    r"""
    Compute the pixel accuracy (PA):

    $$ PA = \frac{\sum_{i=0}^k p_{ii}}{\sum_{i=0}^k \sum_{j=0}^k p_{ij}} $$

    where $n_{class} = k + 1$.

    Parameters:
    -----------
    output, target: shape [N, ] (use flatten() first)

    Returns:
    --------
    PA
    """
    assert output.shape == target.shape, "shapes must be same"
    cm = confusion_matrix(target, output)
    return np.diag(cm).sum() / cm.sum()


def _miou(output, target):
    r"""
    Compute the mean intersection over union (MIoU):

    $$ MIoU = \frac{1}{k+1} \sum_{i=0}^k \frac{p_{ii}}{\sum_{j=0}^k p_{ij} + \sum_{j=0}^k p_{ji} - p_{ii}} $$

    Parameters:
    -----------
    output, target: [N, ]

    Returns:
    --------
    MIoU
    """
    assert output.shape == target.shape, "shapes must be same"
    cm = confusion_matrix(target, output)
    intersection = np.diag(cm)
    union = np.sum(cm, 1) + np.sum(cm, 0) - np.diag(cm)
    iou = intersection / union
    miou = np.nanmean(iou)

    return miou
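A quick worked example of the two metrics above (the prediction and target arrays are made up for illustration, not taken from the repository):

import numpy as np
from sklearn.metrics import confusion_matrix

pred   = np.array([0, 0, 1, 1, 1, 2])
target = np.array([0, 1, 1, 1, 2, 2])

cm = confusion_matrix(target, pred)
# rows = target class, columns = predicted class:
# [[1 0 0]
#  [1 2 0]
#  [0 1 1]]

pa = np.diag(cm).sum() / cm.sum()            # (1 + 2 + 1) / 6 ~ 0.667
union = cm.sum(1) + cm.sum(0) - np.diag(cm)  # [2, 4, 2]
miou = np.nanmean(np.diag(cm) / union)       # mean(0.5, 0.5, 0.5) = 0.5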
def yaml_config_hook(config_file: str) -> argparse.Namespace:
    """
    Load the parameter settings from a yaml file and build an argparse-style namespace from them.
    """
    with open(config_file) as f:
        cfg = yaml.safe_load(f)
        for d in cfg.get("defaults", []):
            config_dir, cf = d.popitem()
            cf = os.path.join(os.path.dirname(config_file), config_dir,
                              cf + ".yaml")
            with open(cf) as f:
                l = yaml.safe_load(f)
                cfg.update(l)

    if "defaults" in cfg.keys():
        del cfg["defaults"]

    parser = argparse.ArgumentParser()
    for k, v in cfg.items():
        parser.add_argument(f"--{k}", default=v, type=type(v))
    args = parser.parse_args()

    return args


def backup_code(work_dir, back_dir, exceptions=[], include=[]):
    r"""
    Back up the code used for this run to a target directory, excluding certain files and directories.

    Args:
        work_dir: working directory, i.e. the code to back up
        back_dir: target directory where the backed-up code is placed
        exceptions (list): excluded directories and file suffixes; the defaults are
            ["__pycache__", ".pyc", ".dat", "backup", ".vscode"]
        include (list): files that must be backed up even if they match an exclusion pattern
    """
    _exp = [
        "*__pycache__*", "*.pyc", "*.dat", "backup", ".vscode", "*.log",
        "*log*"
    ]
    exceptions = exceptions + _exp

    # if not os.path.exists(back_dir):
    os.makedirs(back_dir, exist_ok=True)

    shutil.copytree(work_dir,
                    back_dir + 'code/',
                    ignore=shutil.ignore_patterns(*exceptions),
                    dirs_exist_ok=True)

    for f in include:
        shutil.copyfile(os.path.join(work_dir, f),
                        os.path.join(back_dir + 'code', f))


def list_files(path, full=False):
    r"""
    Recursively list all files under a directory, including files in subdirectories.
    """
    out = []
    for f in os.listdir(path):
        fname = os.path.join(path, f)
        if os.path.isdir(fname):
            fname = list_files(fname)
            out += [os.path.join(f, i) for i in fname]
        else:
            out.append(f)
    if full:
        out = [os.path.join(path, i) for i in out]
    return out


if __name__ == "__main__":
    output = torch.randn(4, 2, 6, 6)
    target = torch.randn(4, 2, 6, 6)
    # output = output.cuda()
    # target = target.cuda()
    target = target.argmax(1)

    accuracy(output, target)
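yaml_config_hook above merges a top-level yaml file with any files listed under its "defaults" key and exposes every entry as a command-line flag. A hypothetical usage sketch, assuming a config.yaml with the contents shown in the comment and the repository root on PYTHONPATH (neither the config file nor train.py is part of the repository):

# config.yaml
# -----------
# batch_size: 32
# lr: 0.001
# epochs: 100

from util.tools import yaml_config_hook

args = yaml_config_hook("config.yaml")
print(args.batch_size, args.lr, args.epochs)  # 32 0.001 100
# Each key becomes an argparse flag with type=type(default), so any entry can still be
# overridden on the command line, e.g. `python train.py --lr 0.0005`.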
util/variable_pos_embed.py ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Variable size position embedding utils for handling different image dimensions
# --------------------------------------------------------

import numpy as np
import torch
import torch.nn.functional as F


def get_2d_sincos_pos_embed_variable(embed_dim, grid_h, grid_w, cls_token=False):
    """
    Create 2D sine-cosine position embeddings for variable grid sizes.

    Args:
        embed_dim: embedding dimension
        grid_h: height of the grid (number of patches in height)
        grid_w: width of the grid (number of patches in width)
        cls_token: whether to include a class token

    Returns:
        pos_embed: [grid_h*grid_w, embed_dim] or [1+grid_h*grid_w, embed_dim] (w/ or w/o cls_token)
    """
    grid_h_coords = np.arange(grid_h, dtype=np.float32)
    grid_w_coords = np.arange(grid_w, dtype=np.float32)
    grid = np.meshgrid(grid_w_coords, grid_h_coords)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_h, grid_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of the dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # np.float was removed in NumPy 1.24; use float64 instead
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def interpolate_pos_embed_variable(original_pos_embed, target_h, target_w, cls_token=True):
    """
    Interpolate position embeddings to an arbitrary target size.

    Args:
        original_pos_embed: original positional embeddings [1, N, D]
        target_h: target height in patches
        target_w: target width in patches
        cls_token: whether the first token is a class token

    Returns:
        interpolated_pos_embed: [1, target_h*target_w + cls_token, D]
    """
    embed_dim = original_pos_embed.shape[-1]

    if cls_token:
        class_pos_embed = original_pos_embed[:, 0:1]  # [1, 1, D]
        patch_pos_embed = original_pos_embed[:, 1:]   # [1, N-1, D]
        orig_num_patches = patch_pos_embed.shape[1]
    else:
        class_pos_embed = None
        patch_pos_embed = original_pos_embed
        orig_num_patches = patch_pos_embed.shape[1]

    # Determine the original grid size (the original grid is assumed to be square)
    orig_h = orig_w = int(np.sqrt(orig_num_patches))

    if orig_h * orig_w != orig_num_patches:
        raise ValueError(f"Original number of patches {orig_num_patches} is not a perfect square")

    # Reshape to spatial dimensions
    patch_pos_embed = patch_pos_embed.reshape(1, orig_h, orig_w, embed_dim)
    patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)  # [1, D, orig_h, orig_w]

    # Interpolate to the target size
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(target_h, target_w),
        mode='bicubic',
        align_corners=False
    )

    # Reshape back to a token sequence
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1)  # [1, target_h, target_w, D]
    patch_pos_embed = patch_pos_embed.flatten(1, 2)  # [1, target_h*target_w, D]

    if cls_token:
        new_pos_embed = torch.cat([class_pos_embed, patch_pos_embed], dim=1)
    else:
        new_pos_embed = patch_pos_embed

    return new_pos_embed


def create_variable_pos_embed(embed_dim, height_patches, width_patches, cls_token=True):
    """
    Create positional embeddings for specific patch grid dimensions.

    Args:
        embed_dim: embedding dimension
        height_patches: number of patches in height
        width_patches: number of patches in width
        cls_token: whether to include a class token

    Returns:
        pos_embed: positional embeddings tensor
    """
    pos_embed_np = get_2d_sincos_pos_embed_variable(
        embed_dim, height_patches, width_patches, cls_token=cls_token
    )
    pos_embed = torch.from_numpy(pos_embed_np).float().unsqueeze(0)
    return pos_embed
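A short usage sketch tying the helpers above together: build embeddings for a square 14x14 pretrained patch grid, then interpolate them to a non-square 12x18 grid (the dimensions below are illustrative, not values used elsewhere in the repository; the two functions are assumed to be imported from util.variable_pos_embed):

embed_dim = 768
# Class token plus a 14x14 patch grid -> shape [1, 1 + 14*14, 768]
pretrained = create_variable_pos_embed(embed_dim, 14, 14, cls_token=True)

# Bicubic interpolation to a 12x18 patch grid -> shape [1, 1 + 12*18, 768]
adapted = interpolate_pos_embed_variable(pretrained, target_h=12, target_w=18, cls_token=True)
print(adapted.shape)  # torch.Size([1, 217, 768])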