Add tropical, compression, curvature, renormalization heads
Browse files- heads/__init__.py +6 -0
- heads/compression/__init__.py +1 -0
- heads/compression/head.py +26 -0
- heads/curvature/__init__.py +1 -0
- heads/curvature/head.py +38 -0
- heads/tropical/__init__.py +1 -0
- heads/tropical/head.py +24 -0
heads/__init__.py
CHANGED
|
@@ -9,6 +9,9 @@ from .patch_attention.head import PatchAttention
|
|
| 9 |
from .graph_crf.head import GraphCRF
|
| 10 |
from .hypercolumn_linear.head import HypercolumnLinear
|
| 11 |
from .info_bottleneck.head import InfoBottleneck
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
REGISTRY = {
|
| 14 |
"linear_probe": LinearProbe,
|
|
@@ -20,6 +23,9 @@ REGISTRY = {
|
|
| 20 |
"graph_crf": GraphCRF,
|
| 21 |
"hypercolumn_linear": HypercolumnLinear,
|
| 22 |
"info_bottleneck": InfoBottleneck,
|
|
|
|
|
|
|
|
|
|
| 23 |
}
|
| 24 |
|
| 25 |
ALL_NAMES = list(REGISTRY.keys())
|
|
|
|
| 9 |
from .graph_crf.head import GraphCRF
|
| 10 |
from .hypercolumn_linear.head import HypercolumnLinear
|
| 11 |
from .info_bottleneck.head import InfoBottleneck
|
| 12 |
+
from .tropical.head import TropicalSegmentation
|
| 13 |
+
from .compression.head import CompressionSegmentation
|
| 14 |
+
from .curvature.head import CurvatureSegmentation
|
| 15 |
|
| 16 |
REGISTRY = {
|
| 17 |
"linear_probe": LinearProbe,
|
|
|
|
| 23 |
"graph_crf": GraphCRF,
|
| 24 |
"hypercolumn_linear": HypercolumnLinear,
|
| 25 |
"info_bottleneck": InfoBottleneck,
|
| 26 |
+
"tropical": TropicalSegmentation,
|
| 27 |
+
"compression": CompressionSegmentation,
|
| 28 |
+
"curvature": CurvatureSegmentation,
|
| 29 |
}
|
| 30 |
|
| 31 |
ALL_NAMES = list(REGISTRY.keys())
|
heads/compression/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/compression/head.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Compression Segmentation: modulate features by local prediction residual.
|
| 2 |
+
Patches that can't be predicted from neighbors get amplified before classification."""
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CompressionSegmentation(nn.Module):
|
| 10 |
+
name = "compression"
|
| 11 |
+
needs_intermediates = False
|
| 12 |
+
|
| 13 |
+
def __init__(self, feat_dim=768, num_classes=150):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.cls = nn.Conv2d(feat_dim, num_classes, 1)
|
| 16 |
+
|
| 17 |
+
def forward(self, spatial, inter=None):
|
| 18 |
+
B, C, H, W = spatial.shape
|
| 19 |
+
kernel = torch.tensor([[1, 1, 1], [1, 0, 1], [1, 1, 1]],
|
| 20 |
+
dtype=spatial.dtype, device=spatial.device) / 8
|
| 21 |
+
kernel = kernel.reshape(1, 1, 3, 3).expand(C, 1, 3, 3)
|
| 22 |
+
neighbor_mean = F.conv2d(spatial, kernel, padding=1, groups=C)
|
| 23 |
+
surprise = (spatial - neighbor_mean).pow(2).sum(dim=1, keepdim=True)
|
| 24 |
+
surprise_norm = surprise / surprise.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
|
| 25 |
+
modulated = spatial * (1 + surprise_norm * 3)
|
| 26 |
+
return self.cls(modulated)
|
heads/curvature/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/curvature/head.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Curvature Segmentation: modulate features by discrete Riemannian curvature.
|
| 2 |
+
High-curvature locations (where the feature manifold bends sharply) are
|
| 3 |
+
segment boundaries. Features at those locations get amplified."""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CurvatureSegmentation(nn.Module):
|
| 11 |
+
name = "curvature"
|
| 12 |
+
needs_intermediates = False
|
| 13 |
+
|
| 14 |
+
def __init__(self, feat_dim=768, num_classes=150):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.cls = nn.Conv2d(feat_dim, num_classes, 1)
|
| 17 |
+
|
| 18 |
+
def _curvature(self, spatial):
|
| 19 |
+
B, C, H, W = spatial.shape
|
| 20 |
+
dx = torch.zeros_like(spatial)
|
| 21 |
+
dx[:, :, :, 1:-1] = (spatial[:, :, :, 2:] - spatial[:, :, :, :-2]) / 2
|
| 22 |
+
dy = torch.zeros_like(spatial)
|
| 23 |
+
dy[:, :, 1:-1, :] = (spatial[:, :, 2:, :] - spatial[:, :, :-2, :]) / 2
|
| 24 |
+
g11 = (dx * dx).sum(dim=1, keepdim=True)
|
| 25 |
+
g22 = (dy * dy).sum(dim=1, keepdim=True)
|
| 26 |
+
det_g = (g11 * g22).clamp(min=1e-10)
|
| 27 |
+
dxx = torch.zeros_like(spatial)
|
| 28 |
+
dxx[:, :, :, 1:-1] = spatial[:, :, :, 2:] - 2 * spatial[:, :, :, 1:-1] + spatial[:, :, :, :-2]
|
| 29 |
+
dyy = torch.zeros_like(spatial)
|
| 30 |
+
dyy[:, :, 1:-1, :] = spatial[:, :, 2:, :] - 2 * spatial[:, :, 1:-1, :] + spatial[:, :, :-2, :]
|
| 31 |
+
laplacian = dxx + dyy
|
| 32 |
+
return laplacian.pow(2).sum(dim=1, keepdim=True) / det_g
|
| 33 |
+
|
| 34 |
+
def forward(self, spatial, inter=None):
|
| 35 |
+
curv = self._curvature(spatial)
|
| 36 |
+
curv_norm = curv / curv.amax(dim=(2, 3), keepdim=True).clamp(min=1e-6)
|
| 37 |
+
modulated = spatial * (1 + curv_norm * 3)
|
| 38 |
+
return self.cls(modulated)
|
heads/tropical/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .head import *
|
heads/tropical/head.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tropical Segmentation: tropical inner product replaces standard dot product.
|
| 2 |
+
logit_c = min_d(w_{c,d} + x_d) instead of sum(w_{c,d} * x_d)."""
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TropicalSegmentation(nn.Module):
|
| 9 |
+
name = "tropical"
|
| 10 |
+
needs_intermediates = False
|
| 11 |
+
|
| 12 |
+
def __init__(self, feat_dim=768, num_classes=150):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.tropical_w = nn.Parameter(torch.randn(num_classes, feat_dim) * 0.01)
|
| 15 |
+
self.bias = nn.Parameter(torch.zeros(num_classes))
|
| 16 |
+
|
| 17 |
+
def forward(self, spatial, inter=None):
|
| 18 |
+
B, C, H, W = spatial.shape
|
| 19 |
+
f = spatial.permute(0, 2, 3, 1).reshape(-1, C)
|
| 20 |
+
# Tropical: min_d(w_{c,d} + f_d), smooth approx via -logsumexp(-beta*(w+f))/beta
|
| 21 |
+
beta = 10.0
|
| 22 |
+
expanded = self.tropical_w.unsqueeze(0) + f.unsqueeze(1)
|
| 23 |
+
logits = -torch.logsumexp(-beta * expanded, dim=2) / beta + self.bias
|
| 24 |
+
return logits.reshape(B, H, W, -1).permute(0, 3, 1, 2)
|