# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm.models.vision_transformer
import numpy as np

from util.msssim import MSSSIM
from util.pos_embed import get_2d_sincos_pos_embed
from util.variable_pos_embed import interpolate_pos_embed_variable


class FlexiblePatchEmbed(nn.Module):
    """2D image to patch embedding that handles variable input sizes."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, bias=True):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2  # default number of patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)

    def forward(self, x):
        B, C, H, W = x.shape
        # Recompute the number of patches for the actual input size
        self.num_patches = (H // self.patch_size) * (W // self.patch_size)
        x = self.proj(x).flatten(2).transpose(1, 2)  # BCHW -> BNC
        return x


class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Vision Transformer with support for global average pooling and a segmentation decoder."""

    def __init__(self, global_pool=False, **kwargs):
        super(VisionTransformer, self).__init__(**kwargs)

        self.global_pool = global_pool
        self.decoder = DecoderCup(in_channels=[self.embed_dim, 256, 128, 64])
        self.segmentation_head = SegmentationHead(
            in_channels=64,
            out_channels=self.num_classes,
            kernel_size=1,
        )

        if self.global_pool:
            norm_layer = kwargs['norm_layer']
            embed_dim = kwargs['embed_dim']
            self.fc_norm = norm_layer(embed_dim)
            del self.norm  # remove the original norm

    def interpolate_pos_encoding(self, x, h, w):
        """Interpolate positional embeddings for arbitrary input sizes."""
        npatch = x.shape[1] - 1  # subtract 1 for the cls token
        N = self.pos_embed.shape[1] - 1  # original number of patches
        if npatch == N and h == w:
            return self.pos_embed
        # Use the variable position embedding utility
        return interpolate_pos_embed_variable(self.pos_embed, h, w, cls_token=True)

    def generate_mask(self, input_tensor, ratio):
        """Mask a random subset of 16-pixel-wide vertical strips (1 = masked, 0 = kept)."""
        mask = torch.zeros_like(input_tensor)
        num_strips = mask.size(3) // 16
        indices = torch.randperm(num_strips)[:int(num_strips * ratio)]
        sorted_indices = torch.sort(indices)[0]
        for idx in sorted_indices.tolist():
            mask[:, :, :, idx * 16:(idx + 1) * 16] = 1
        return mask

    def forward_features(self, x):
        B, C, H, W = x.shape

        # Pad so that H and W are divisible by the patch size
        patch_size = self.patch_embed.patch_size
        pad_h = (patch_size - H % patch_size) % patch_size
        pad_w = (patch_size - W % patch_size) % patch_size
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
            H_padded, W_padded = H + pad_h, W + pad_w
        else:
            H_padded, W_padded = H, W

        img = x
        x = self.patch_embed(x)
        _H, _W = H_padded // patch_size, W_padded // patch_size

        # Add class token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add interpolated positional embeddings
        pos_embed = self.interpolate_pos_encoding(x, _H, _W)
        x = x + pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        # Drop the cls token and decode the patch tokens into a dense prediction
        x = self.decoder(x[:, 1:, :], img)
        x = self.segmentation_head(x)
        return x
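
    # Shape summary for forward_features (patch_size = 16, square H x W input; H and W are
    # first padded to multiples of the patch size):
    #   [B, C, H, W] --patch_embed--> [B, N, embed_dim], N = (H/16) * (W/16)
    #   + cls token and (interpolated) positional embeddings, then the transformer blocks
    #   patch tokens --DecoderCup (four 2x bilinear upsamplings)--> [B, 64, H, W]
    #   --SegmentationHead (1x1 conv)--> [B, num_classes, H, W]
    # Note that DecoderCup reshapes the tokens with sqrt(n_patch), so it assumes a square
    # patch grid.
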
    def forward(self, x):
        x = self.forward_features(x)
        return x

    def inference(self, x):
        x = self.forward_features(x)
        x = F.softmax(x, dim=1)  # per-pixel class probabilities
        return x


class Conv2dReLU(nn.Sequential):
    """Conv2d -> BatchNorm2d -> ReLU block."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )
        relu = nn.ReLU(inplace=True)
        bn = nn.BatchNorm2d(out_channels)
        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    """2x bilinear upsampling followed by two Conv2dReLU layers, with an optional skip input."""

    def __init__(
        self,
        in_channels,
        out_channels,
        skip_channels=0,
        use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=1, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=0)
        upsample = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsample)


class DecoderCup(nn.Module):
    def __init__(self, in_channels=[1024, 256, 128, 64]):
        super().__init__()
        head_channels = 512
        # Projects the single-channel input image to 32 channels for the final skip connection
        self.conv_more = Conv2dReLU(
            1,
            32,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        skip_channels = [0, 0, 0, 32]
        out_channels = [256, 128, 64, 64]
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch)
            for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, img, features=None):
        # Reshape from (B, n_patch, hidden) to (B, hidden, h, w); assumes a square patch grid
        B, n_patch, hidden = hidden_states.size()
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        skips = [None, None, None, self.conv_more(img)]  # only the last block receives a skip
        for i, decoder_block in enumerate(self.blocks):
            x = decoder_block(x, skip=skips[i])
        return x


def forward_loss(imgs, pred):
    """
    Reconstruction loss: equal-weight combination of pixel-wise MSE and the MS-SSIM term.
    imgs: [N, C, H, W] target images
    pred: [N, C, H, W] reconstructed images
    """
    loss1f = torch.nn.MSELoss()
    loss1 = loss1f(imgs, pred)
    loss2f = MSSSIM()
    loss2 = loss2f(imgs, pred)
    a = 0.5
    loss = (1 - a) * loss1 + a * loss2
    return loss
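
# Note: forward_loss above is an image-reconstruction objective (equal-weight MSE and MS-SSIM),
# separate from the segmentation objective below; weighted_cross_entropy is the class-balanced
# loss for the binary fault / non-fault targets.
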
def weighted_cross_entropy(pred, target):
    """
    Compute the class-balanced (weighted) binary cross-entropy loss.

    pred:   [batch, channel, s, s] predicted fault probabilities in (0, 1)
    target: [batch, channel, s, s] binary fault labels (1 = fault, 0 = non-fault)

    TODO: needs verification.
    """
    # beta is the fraction of non-fault pixels in the target (i.e. the zeros in the target);
    # weighting fault pixels by beta and non-fault pixels by (1 - beta) balances the two classes.
    beta = torch.mean(target)  # fraction of fault pixels
    beta = 1 - beta  # fraction of non-fault pixels
    beta = torch.clamp(beta, min=0.01, max=0.99)  # keep both class weights non-degenerate

    loss = -(beta * target * torch.log(pred + 1e-8)
             + (1 - beta) * (1 - target) * torch.log(1 - pred + 1e-8))
    return torch.mean(loss)


def mae_vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=6, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768,
    )
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=768,
    )
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=16,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1024,
    )
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    # Replace with flexible patch embedding
    model.patch_embed = FlexiblePatchEmbed(
        img_size=kwargs.get('img_size', 224),
        patch_size=14,
        in_chans=kwargs.get('in_chans', 3),
        embed_dim=1280,
    )
    return model
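

# The sketch below is illustrative only and not part of the original pipeline: it builds the
# base model for a single-channel input (required by DecoderCup's 1-channel conv_more skip),
# with an assumed num_classes=2, runs a dummy forward pass, and evaluates the weighted
# cross-entropy against a random binary mask.  It still requires the util.* modules imported
# at the top of this file.
if __name__ == "__main__":
    model = vit_base_patch16(img_size=224, in_chans=1, num_classes=2)
    model.eval()

    image = torch.randn(1, 1, 224, 224)  # dummy single-channel input
    with torch.no_grad():
        logits = model(image)  # [1, 2, 224, 224] per-pixel class logits
    probs = F.softmax(logits, dim=1)  # equivalent to model.inference(image)

    target = torch.randint(0, 2, (1, 1, 224, 224)).float()  # dummy binary fault mask
    loss = weighted_cross_entropy(probs[:, 1:2], target)  # fault probability assumed to be class 1
    print(logits.shape, loss.item())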