phanerozoic commited on
Commit
cd7a8ba
·
verified ·
1 Parent(s): d2251ba

Add tropical, compression, curvature, renormalization heads

Browse files
heads/__init__.py CHANGED
@@ -9,6 +9,9 @@ from .patch_attention.head import PatchAttention
9
  from .graph_crf.head import GraphCRF
10
  from .hypercolumn_linear.head import HypercolumnLinear
11
  from .info_bottleneck.head import InfoBottleneck
 
 
 
12
 
13
  REGISTRY = {
14
  "linear_probe": LinearProbe,
@@ -20,6 +23,9 @@ REGISTRY = {
20
  "graph_crf": GraphCRF,
21
  "hypercolumn_linear": HypercolumnLinear,
22
  "info_bottleneck": InfoBottleneck,
 
 
 
23
  }
24
 
25
  ALL_NAMES = list(REGISTRY.keys())
 
9
  from .graph_crf.head import GraphCRF
10
  from .hypercolumn_linear.head import HypercolumnLinear
11
  from .info_bottleneck.head import InfoBottleneck
12
+ from .tropical.head import TropicalSegmentation
13
+ from .compression.head import CompressionSegmentation
14
+ from .curvature.head import CurvatureSegmentation
15
 
16
  REGISTRY = {
17
  "linear_probe": LinearProbe,
 
23
  "graph_crf": GraphCRF,
24
  "hypercolumn_linear": HypercolumnLinear,
25
  "info_bottleneck": InfoBottleneck,
26
+ "tropical": TropicalSegmentation,
27
+ "compression": CompressionSegmentation,
28
+ "curvature": CurvatureSegmentation,
29
  }
30
 
31
  ALL_NAMES = list(REGISTRY.keys())
heads/compression/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/compression/head.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Compression Segmentation: modulate features by local prediction residual.
2
+ Patches that can't be predicted from neighbors get amplified before classification."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class CompressionSegmentation(nn.Module):
10
+ name = "compression"
11
+ needs_intermediates = False
12
+
13
+ def __init__(self, feat_dim=768, num_classes=150):
14
+ super().__init__()
15
+ self.cls = nn.Conv2d(feat_dim, num_classes, 1)
16
+
17
+ def forward(self, spatial, inter=None):
18
+ B, C, H, W = spatial.shape
19
+ kernel = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]],
20
+ dtype=spatial.dtype, device=spatial.device) / 8
21
+ kernel = kernel.reshape(1, 1, 3, 3).expand(C, 1, 3, 3)
22
+ neighbor_mean = F.conv2d(spatial, kernel, padding=1, groups=C)
23
+ surprise = (spatial - neighbor_mean).pow(2).sum(dim=1, keepdim=True)
24
+ surprise_norm = surprise / surprise.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
25
+ modulated = spatial * (1 + surprise_norm * 3)
26
+ return self.cls(modulated)
heads/curvature/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/curvature/head.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Curvature Segmentation: modulate features by discrete Riemannian curvature.
2
+ High-curvature locations (where the feature manifold bends sharply) are
3
+ segment boundaries. Features at those locations get amplified."""
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class CurvatureSegmentation(nn.Module):
11
+ name = "curvature"
12
+ needs_intermediates = False
13
+
14
+ def __init__(self, feat_dim=768, num_classes=150):
15
+ super().__init__()
16
+ self.cls = nn.Conv2d(feat_dim, num_classes, 1)
17
+
18
+ def _curvature(self, spatial):
19
+ B, C, H, W = spatial.shape
20
+ dx = torch.zeros_like(spatial)
21
+ dx[:, :, :, 1:-1] = (spatial[:, :, :, 2:] - spatial[:, :, :, :-2]) / 2
22
+ dy = torch.zeros_like(spatial)
23
+ dy[:, :, 1:-1, :] = (spatial[:, :, 2:, :] - spatial[:, :, :-2, :]) / 2
24
+ g11 = (dx * dx).sum(dim=1, keepdim=True)
25
+ g22 = (dy * dy).sum(dim=1, keepdim=True)
26
+ det_g = (g11 * g22).clamp(min=1e-10)
27
+ dxx = torch.zeros_like(spatial)
28
+ dxx[:, :, :, 1:-1] = spatial[:, :, :, 2:] - 2 * spatial[:, :, :, 1:-1] + spatial[:, :, :, :-2]
29
+ dyy = torch.zeros_like(spatial)
30
+ dyy[:, :, 1:-1, :] = spatial[:, :, 2:, :] - 2 * spatial[:, :, 1:-1, :] + spatial[:, :, :-2, :]
31
+ laplacian = dxx + dyy
32
+ return laplacian.pow(2).sum(dim=1, keepdim=True) / det_g
33
+
34
+ def forward(self, spatial, inter=None):
35
+ curv = self._curvature(spatial)
36
+ curv_norm = curv / curv.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
37
+ modulated = spatial * (1 + curv_norm * 3)
38
+ return self.cls(modulated)
heads/tropical/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .head import *
heads/tropical/head.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tropical Segmentation: tropical inner product replaces standard dot product.
2
+ logit_c = min_d(w_{c,d} + x_d) instead of sum(w_{c,d} * x_d)."""
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+
8
+ class TropicalSegmentation(nn.Module):
9
+ name = "tropical"
10
+ needs_intermediates = False
11
+
12
+ def __init__(self, feat_dim=768, num_classes=150):
13
+ super().__init__()
14
+ self.tropical_w = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
15
+ self.bias = nn.Parameter(torch.zeros(num_classes))
16
+
17
+ def forward(self, spatial, inter=None):
18
+ B, C, H, W = spatial.shape
19
+ f = spatial.permute(0, 2, 3, 1).reshape(-1, C)
20
+ # Tropical: min_d(w_{c,d} + f_d), smooth approx via -logsumexp(-beta*(w+f))/beta
21
+ beta = 10.0
22
+ expanded = self.tropical_w.unsqueeze(0) + f.unsqueeze(1)
23
+ logits = -torch.logsumexp(-beta * expanded, dim=2) / beta + self.bias
24
+ return logits.reshape(B, H, W, -1).permute(0, 3, 1, 2)