import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class INP_Former(nn.Module):
    """INP-Former reconstruction model over a (mostly frozen) ViT encoder.

    Features from selected encoder blocks are collected, mean-fused, passed
    through a bottleneck, and reconstructed by a decoder that attends to a
    set of learned prototype tokens.  ``forward`` returns fused encoder maps,
    fused decoder maps, and a prototype "gather" loss.
    """

    def __init__(
        self,
        encoder,
        bottleneck,
        aggregation,
        decoder,
        target_layers=(2, 3, 4, 5, 6, 7, 8, 9),
        fuse_layer_encoder=((0, 1, 2, 3, 4, 5, 6, 7),),
        fuse_layer_decoder=((0, 1, 2, 3, 4, 5, 6, 7),),
        remove_class_token=False,
        encoder_require_grad_layer=(),
        prototype_token=None,
    ) -> None:
        """
        Args:
            encoder: ViT backbone exposing ``prepare_tokens`` and ``blocks``.
            bottleneck: iterable of blocks applied to the fused encoder feature.
            aggregation: iterable of blocks called as ``blk(prototypes, x)``
                that refine the prototype tokens against the image feature.
            decoder: iterable of blocks called as ``blk(x, prototypes)``;
                each output is collected as one reconstruction layer.
            target_layers: encoder block indices whose outputs are collected.
            fuse_layer_encoder: groups of indices into the collected encoder
                features; each group is mean-fused into one output map.
            fuse_layer_decoder: same, for decoder outputs (reversed order).
            remove_class_token: if True, strip class/register tokens right
                after the encoder instead of just before the final reshape.
            encoder_require_grad_layer: encoder block indices run WITH grad;
                every other encoder block runs under ``torch.no_grad()``.
            prototype_token: one-element sequence holding the learnable
                prototype token tensor of shape ``(num_prototypes, dim)``.
        """
        super().__init__()
        self.encoder = encoder
        self.bottleneck = bottleneck
        self.aggregation = aggregation
        self.decoder = decoder
        # Defaults are immutable tuples (no shared-mutable-default pitfall);
        # copy to lists so stored state is independent of caller arguments.
        self.target_layers = list(target_layers)
        self.fuse_layer_encoder = [list(group) for group in fuse_layer_encoder]
        self.fuse_layer_decoder = [list(group) for group in fuse_layer_decoder]
        self.remove_class_token = remove_class_token
        self.encoder_require_grad_layer = list(encoder_require_grad_layer)
        # NOTE(review): callers wrap the prototype tensor in a one-element
        # sequence (presumably to keep it out of this module's parameter
        # registration) — confirm against the training script.
        self.prototype_token = prototype_token[0]
        # Plain ViTs (no DINOv2-style register tokens) lack this attribute.
        if not hasattr(self.encoder, 'num_register_tokens'):
            self.encoder.num_register_tokens = 0

    def gather_loss(self, query, keys):
        """Mean cosine distance from each query token to its nearest key.

        Encourages every image token to lie close to at least one prototype.
        Side effects: caches ``self.distribution`` (B, Lq, Lk) pairwise
        distances, plus ``self.distance`` / ``self.cluster_index`` of the
        nearest key per query token.
        """
        self.distribution = 1. - F.cosine_similarity(query.unsqueeze(2), keys.unsqueeze(1), dim=-1)
        self.distance, self.cluster_index = torch.min(self.distribution, dim=2)
        return self.distance.mean()

    def forward(self, x):
        """Encode, aggregate prototypes, reconstruct.

        Returns:
            en: list of fused encoder maps, each (B, C, side, side).
            de: list of fused decoder maps, same shapes.
            g_loss: scalar prototype gather loss.
        """
        x = self.encoder.prepare_tokens(x)
        B, L, _ = x.shape

        en_list = []
        for i, blk in enumerate(self.encoder.blocks):
            if i > self.target_layers[-1]:
                break  # deeper encoder blocks are never used
            if i in self.encoder_require_grad_layer:
                x = blk(x)
            else:
                with torch.no_grad():
                    x = blk(x)
            if i in self.target_layers:
                en_list.append(x)

        # Patch tokens form a square grid; exclude cls + register tokens.
        side = int(math.sqrt(en_list[0].shape[1] - 1 - self.encoder.num_register_tokens))

        if self.remove_class_token:
            en_list = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en_list]

        x = self.fuse_feature(en_list)

        # Refine the learned prototypes against this image's fused feature.
        agg_prototype = self.prototype_token
        for blk in self.aggregation:
            agg_prototype = blk(agg_prototype.unsqueeze(0).repeat((B, 1, 1)), x)
        g_loss = self.gather_loss(x, agg_prototype)

        for blk in self.bottleneck:
            x = blk(x)

        de_list = []
        for blk in self.decoder:
            x = blk(x, agg_prototype)
            de_list.append(x)
        de_list = de_list[::-1]  # pair shallow decoder outputs with deep encoder ones

        en = [self.fuse_feature([en_list[idx] for idx in idxs]) for idxs in self.fuse_layer_encoder]
        de = [self.fuse_feature([de_list[idx] for idx in idxs]) for idxs in self.fuse_layer_decoder]

        if not self.remove_class_token:  # class tokens have not been removed above
            en = [e[:, 1 + self.encoder.num_register_tokens:, :] for e in en]
            de = [d[:, 1 + self.encoder.num_register_tokens:, :] for d in de]

        # (B, L, C) -> (B, C, side, side) spatial maps.
        en = [e.permute(0, 2, 1).reshape([x.shape[0], -1, side, side]).contiguous() for e in en]
        de = [d.permute(0, 2, 1).reshape([x.shape[0], -1, side, side]).contiguous() for d in de]
        return en, de, g_loss

    def fuse_feature(self, feat_list):
        """Element-wise mean of a list of same-shaped (B, L, C) features."""
        return torch.stack(feat_list, dim=1).mean(dim=1)