| | from typing import Optional |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from einops import rearrange |
| |
|
| |
|
class MainModel(nn.Module):
    """Encoder/decoder segmentation model with a feature consistency volume.

    ``forward`` produces a per-pixel output map (after sigmoid) and a pairwise
    "consistency volume" built from projected encoder features, plus an
    image-level score pooled from each of the two.

    Parameters
    ----------
    encoder : module called as ``encoder(image, return_feature_maps=True)``;
        must return an indexable collection of feature maps.
    decoder : module called as ``decoder(features, segSize=...)``; returns raw
        logits for the output map.
    fc_dim : channel count of the encoder feature map at block index 3; used
        to derive the channel count at ``volume_block_idx``.
    volume_block_idx : index of the encoder feature map used for the volume.
    share_embed_head : when True, a single projection head is used for both
        sides of the consistency volume.
    pre_filter : optional callable applied to the input image before encoding.
    use_gem : pool map/volume scores with a generalized-mean style p-norm.
    gem_coef : norm order used when ``use_gem`` is set.
    use_gsm : pool with weights from an L1-ball projection of the scores.
    map_portion : fraction of top-scoring elements averaged when pooling
        (0 means plain global max). Must be <= 0.5.
    otsu_sel : select foreground pixels with a per-image Otsu split before
        averaging (only used when neither use_gem nor use_gsm is set).
    otsu_portion : fraction of the Otsu-selected pixels to average.
    """

    def __init__(
        self,
        encoder,
        decoder,
        fc_dim: int,
        volume_block_idx: int,
        share_embed_head: bool,
        pre_filter=None,
        use_gem: bool = False,
        gem_coef: Optional[float] = None,
        use_gsm: bool = False,
        map_portion: float = 0,
        otsu_sel: bool = False,
        otsu_portion: float = 1,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.use_gem = use_gem
        self.gem_coef = gem_coef
        self.use_gsm = use_gsm
        self.map_portion = map_portion
        # Bug fix: the check permits equality with 0.5, but the old message
        # claimed "less than 0.5"; message now matches the condition.
        assert self.map_portion <= 0.5, "Map_portion must be at most 0.5"
        self.otsu_sel = otsu_sel
        self.otsu_portion = otsu_portion

        # Channel count at the chosen block, derived from fc_dim at block 3
        # (each block halves/doubles channels by a factor of 2 per step).
        self.volume_block_idx = volume_block_idx
        volume_in_channel = int(fc_dim * (2 ** (self.volume_block_idx - 3)))
        volume_out_channel = volume_in_channel // 2

        # Temperature for the feature dot products (sqrt of embed dim).
        self.scale = volume_out_channel**0.5
        self.share_embed_head = share_embed_head
        self.proj_head1 = nn.Sequential(
            nn.Conv2d(
                volume_in_channel, volume_in_channel, kernel_size=1, stride=1, padding=0
            ),
            nn.LeakyReLU(),
            nn.Conv2d(
                volume_in_channel,
                volume_out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )
        if not share_embed_head:
            self.proj_head2 = nn.Sequential(
                nn.Conv2d(
                    volume_in_channel,
                    volume_in_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ),
                nn.LeakyReLU(),
                nn.Conv2d(
                    volume_in_channel,
                    volume_out_channel,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                ),
            )

        self.pre_filter = pre_filter

    def forward(self, image, seg_size=None):
        """
        for output maps, the return value is the raw logits
        for consistency volume, the return value is the value after sigmoid

        Returns a dict:
          - "out_map": per-pixel map after sigmoid (clamped to [0, 1] when
            resized to ``seg_size``);
          - "map_pred": image-level score pooled from the map, shape (B,);
          - "out_vol": consistency volume (B, h, w, h, w) in [0, 1];
          - "vol_pred": image-level score pooled from the volume, shape (B,).
        """
        bs = image.shape[0]
        if self.pre_filter is not None:
            image = self.pre_filter(image)

        encoder_feature = self.encoder(image, return_feature_maps=True)
        output_map = self.decoder(encoder_feature, segSize=seg_size)
        output_map = output_map.sigmoid()

        # ---- pool the output map into one score per image ----
        if self.use_gem:
            mh, mw = output_map.shape[-2:]
            # p-norm over all pixels, normalized by the pixel count.
            # NOTE(review): standard GeM divides by N ** (1/p), not N — confirm
            # this normalization is intended.
            image_pred = output_map.flatten(1)
            image_pred = torch.linalg.norm(image_pred, ord=self.gem_coef, dim=1)
            image_pred = image_pred / (mh * mw)
        elif self.use_gsm:
            # Weight each pixel by its L1-ball projection, then sum.
            image_pred = output_map.flatten(1)
            weight = project_onto_l1_ball(image_pred, 1.0)
            image_pred = (image_pred * weight).sum(1)
        else:
            if self.otsu_sel:
                # Per-image Otsu split over sorted pixel scores; average the
                # top otsu_portion fraction above the split.
                n_pixel = output_map.shape[-1] * output_map.shape[-2]
                image_pred = output_map.flatten(1)
                image_pred, _ = torch.sort(image_pred, dim=1)
                tmp = []
                for b in range(bs):
                    num_otsu_sel = get_otsu_k(image_pred[b, ...], sorted=True)
                    # NOTE(review): max() forces the split index past the
                    # median, shrinking the selected set — verify that min()
                    # was not intended here.
                    num_otsu_sel = max(num_otsu_sel, n_pixel // 2 + 1)
                    tpk = int(max(1, (n_pixel - num_otsu_sel) * self.otsu_portion))
                    topk_output = torch.topk(image_pred[b, ...], k=tpk, dim=0)[0]
                    tmp.append(topk_output.mean())
                image_pred = torch.stack(tmp)
            else:
                if self.map_portion == 0:
                    # Plain global max pooling over the map.
                    image_pred = nn.functional.max_pool2d(
                        output_map, kernel_size=output_map.shape[-2:]
                    )
                    # (B, 1, 1, 1) -> (B,)
                    image_pred = image_pred.squeeze(1).squeeze(1).squeeze(1)
                else:
                    # Mean of the top map_portion fraction of pixels.
                    n_pixel = output_map.shape[-1] * output_map.shape[-2]
                    k = int(max(1, int(self.map_portion * n_pixel)))
                    topk_output = torch.topk(output_map.flatten(1), k, dim=1)[0]
                    image_pred = topk_output.mean(1)

        if seg_size is not None:
            output_map = nn.functional.interpolate(
                output_map, size=seg_size, mode="bilinear", align_corners=False
            )
            # Bilinear resampling can overshoot; keep probabilities in [0, 1].
            output_map = output_map.clamp(0, 1)

        # ---- pairwise consistency volume from projected features ----
        feature_map1 = self.proj_head1(encoder_feature[self.volume_block_idx])
        if not self.share_embed_head:
            feature_map2 = self.proj_head2(encoder_feature[self.volume_block_idx])
        else:
            feature_map2 = feature_map1.clone()
        b, c, h, w = feature_map1.shape
        feature_map1 = rearrange(feature_map1, "b c h w -> b c (h w)")
        feature_map2 = rearrange(feature_map2, "b c h w -> b c (h w)")
        # All-pairs dot products between spatial positions.
        consistency_volume = torch.bmm(feature_map1.transpose(-1, -2), feature_map2)
        consistency_volume = rearrange(
            consistency_volume, "b (h1 w1) (h2 w2) -> b h1 w1 h2 w2", h1=h, h2=h
        )
        consistency_volume = consistency_volume / self.scale
        # High value = the two positions disagree (1 - similarity).
        consistency_volume = 1 - consistency_volume.sigmoid()

        # ---- pool the volume into one score per image (same choices) ----
        vh, vw = consistency_volume.shape[-2:]
        if self.use_gem:
            volume_image_pred = consistency_volume.flatten(1)
            volume_image_pred = torch.linalg.norm(
                volume_image_pred, ord=self.gem_coef, dim=1
            )
            volume_image_pred = volume_image_pred / (vh * vw * vh * vw)
        elif self.use_gsm:
            volume_image_pred = consistency_volume.flatten(1)
            weight = project_onto_l1_ball(volume_image_pred, 1.0)
            volume_image_pred = (volume_image_pred * weight).sum(1)
        else:
            if self.map_portion == 0:
                volume_image_pred = torch.max(consistency_volume.flatten(1), dim=1)[0]
            else:
                n_ele = vh * vw * vh * vw
                k = int(max(1, int(self.map_portion * n_ele)))
                topk_output = torch.topk(consistency_volume.flatten(1), k, dim=1)[0]
                volume_image_pred = topk_output.mean(1)

        return {
            "out_map": output_map,
            "map_pred": image_pred,
            "out_vol": consistency_volume,
            "vol_pred": volume_image_pred,
        }
| |
|
| |
|
def project_onto_l1_ball(x, eps):
    """Project each row of a batch onto the L1 ball of radius *eps*.

    Solves, per batch element,

        min ||x - u||_2   s.t.   ||u||_1 <= eps

    using the sorting-based algorithm of Duchi et al. (ICML 2008); cost is
    O(d log d) per row due to the sort.

    Parameters
    ----------
    x: (batch_size, *) torch tensor; trailing dims are flattened internally.
    eps: float, radius of the L1 ball.

    Returns
    -------
    Tensor with the same shape as *x*. Rows whose L1 norm is already below
    *eps* come back unchanged.
    """
    with torch.no_grad():
        shape_in = x.shape
        flat = x.view(shape_in[0], -1)
        # Mask of rows that are already strictly inside the ball.
        inside = (torch.norm(flat, p=1, dim=1) < eps).float().unsqueeze(1)
        # Sort magnitudes descending and take running sums.
        sorted_abs, _ = torch.sort(torch.abs(flat), dim=1, descending=True)
        running = torch.cumsum(sorted_abs, dim=1)
        positions = torch.arange(1, flat.shape[1] + 1, device=flat.device)
        # rho = largest k with sorted_abs[k-1] * k > running[k-1] - eps.
        rho, _ = torch.max((sorted_abs * positions > (running - eps)) * positions, dim=1)
        # Soft-threshold level theta derived from the rho-th running sum.
        theta = (running[torch.arange(flat.shape[0]), rho.cpu() - 1] - eps) / rho
        shrunk = (torch.abs(flat) - theta.unsqueeze(1)).clamp(min=0)
        # Blend: keep inside rows as-is, project the rest (sign restored).
        result = inside * flat + (1 - inside) * shrunk * torch.sign(flat)
        return result.view(shape_in)
| |
|
| |
|
def get_otsu_k(attention, return_value=False, sorted=False):
    """Find an Otsu-style split point of a 1-D score tensor.

    Minimises the weighted intra-class variance over split indices via a
    local hill-descent starting at the midpoint (not an exhaustive scan):
    first walk left while the cost improves; only if no left step helped,
    walk right.

    Parameters
    ----------
    attention: 1-D tensor of scores.
    return_value: if True, return the score at the split instead of its index.
    sorted: set True when *attention* is already sorted ascending (skips sort).
    """

    def _split_cost(seq, pivot: int):
        # Weighted sum of the (unbiased) variances of the two halves.
        n = seq.shape[0]
        left_w = pivot / n
        return left_w * seq[:pivot].var() + (1 - left_w) * seq[pivot:].var()

    n = attention.shape[0]
    # Degenerate lengths have fixed answers.
    if n == 1:
        return 0
    elif n == 2:
        return 1
    if not sorted:
        attention, _ = torch.sort(attention)

    best = n // 2
    best_cost = _split_cost(attention, best)

    # Phase 1: descend leftward from the midpoint while cost does not worsen.
    improved_left = False
    for cand in range(best - 1, 0, -1):
        cost = _split_cost(attention, cand)
        if cost > best_cost:
            break
        best_cost = cost
        best = cand
        improved_left = True

    # Phase 2: only when the left walk never improved, try rightward.
    if not improved_left:
        for cand in range(best + 1, n):
            cost = _split_cost(attention, cand)
            if cost > best_cost:
                break
            best_cost = cost
            best = cand

    return attention[best] if return_value else best
| |
|
| |
|
if __name__ == "__main__":
    # Smoke check: construct the model without a real encoder/decoder.
    # NOTE(review): "srm" lands in the ``pre_filter`` slot, which forward()
    # would call as a function — presumably a callable filter is expected
    # here; verify before running forward().
    model = MainModel(
        encoder=None,
        decoder=None,
        fc_dim=1024,
        volume_block_idx=2,
        share_embed_head=True,
        pre_filter="srm",
    )
| |
|