# argus / argus.py
# phanerozoic's picture
# argus.py: factor padding_mode through DPT blocks; add depth crop_border kwarg; batched correspond. README: align with shipped 3M cofiber detection head, drop FCOS framing, fix file sizes/param counts, add real IN1k val comparison, document qkv-bias choice. eval JSON: strip personal paths.
# 5eaa5db verified
"""
Argus: multi-task perception on a single EUPE-ViT-B backbone.
from transformers import AutoModel
model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True)
result = model.perceive(image)
The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus
task heads are inlined below. The backbone code is reproduced from
facebookresearch/EUPE (Meta FAIR) under the FAIR Research License.
"""
import math
import time
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init
from PIL import Image
from torch import Tensor, nn
from torchvision.ops import nms
from torchvision.transforms import v2
from transformers import PretrainedConfig, PreTrainedModel
# ===========================================================================
# EUPE backbone — vendored verbatim from facebookresearch/EUPE
# ===========================================================================
# ---------- utility helpers (from eupe/utils/utils.py) ---------------------
def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]:
    """Flatten each tensor to 2-D (rows, channels) and concatenate them row-wise.

    Returns the concatenated tensor plus what ``uncat_with_shapes`` needs to
    undo the operation: every input's original shape and the number of rows it
    contributed (the product of all its dimensions except the last).
    """
    shapes = []
    num_tokens = []
    pieces = []
    for tensor in x_list:
        shapes.append(tensor.shape)
        # One channel slice has exactly as many elements as there are rows.
        num_tokens.append(tensor.select(dim=-1, index=0).numel())
        pieces.append(tensor.flatten(0, -2))
    return torch.cat(pieces), shapes, num_tokens
def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]:
    """Inverse of ``cat_keep_shapes``.

    Splits ``flattened`` along dim 0 into chunks of ``num_tokens`` rows and
    restores each chunk's original leading dimensions. The trailing (channel)
    size is taken from ``flattened``, so a channel-changing op applied between
    the cat and the uncat is supported.
    """
    channels = torch.Size([flattened.shape[-1]])
    chunks = torch.split_with_sizes(flattened, num_tokens, dim=0)
    return [chunk.reshape(shape[:-1] + channels) for chunk, shape in zip(chunks, shapes)]
def named_apply(
    fn: Callable,
    module: nn.Module,
    name: str = "",
    depth_first: bool = True,
    include_root: bool = False,
) -> nn.Module:
    """Recursively call ``fn(module=..., name=...)`` over ``module``'s subtree.

    Children are visited in ``named_children`` order and are always included.
    The root itself is visited only when ``include_root`` is True. With
    ``depth_first`` every module is visited after its descendants (post-order);
    otherwise before them (pre-order). Child names are dotted paths relative to
    the initial ``name``. Returns ``module``.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for sub_name, sub_module in module.named_children():
        qualified = f"{name}.{sub_name}" if name else sub_name
        named_apply(
            fn=fn,
            module=sub_module,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
# ---------- RMSNorm (from eupe/layers/rms_norm.py) -------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable gain.

    The statistic is computed in float32. The normalized tensor is cast back to
    the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def reset_parameters(self) -> None:
        nn.init.constant_(self.weight, 1)

    def _norm(self, x: Tensor) -> Tensor:
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
# ---------- LayerScale (from eupe/layers/layer_scale.py) -------------------
class LayerScale(nn.Module):
    """Learnable per-channel scaling, typically applied to a residual branch.

    ``gamma`` is allocated uninitialized; ``reset_parameters`` fills it with
    ``init_values``. With ``inplace=True`` the input is scaled in place
    (``mul_``) and returned.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.empty(dim, device=device))
        self.init_values = init_values

    def reset_parameters(self):
        nn.init.constant_(self.gamma, self.init_values)

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
# ---------- PatchEmbed (from eupe/layers/patch_embed.py) -------------------
def make_2tuple(x):
    """Return ``x`` as an (h, w) pair: an int is duplicated, a 2-tuple passes through."""
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)
    assert len(x) == 2
    return x
class PatchEmbed(nn.Module):
    """Split a (B, C, H, W) image into non-overlapping patches and embed them.

    The embedding is a Conv2d with kernel == stride == patch size, so the
    output grid is (H // patch_h, W // patch_w). With ``flatten_embedding`` the
    result is (B, grid_h * grid_w, embed_dim); otherwise it is
    (B, grid_h, grid_w, embed_dim).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
        self.img_size = image_HW
        self.patch_size = patch_HW
        # Grid / patch count at the nominal img_size only; forward() reads the
        # actual grid from the conv output, so other input sizes work too.
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        x = self.proj(x)  # (B, embed_dim, grid_h, grid_w)
        grid_h, grid_w = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # (B, grid_h * grid_w, embed_dim)
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, grid_h, grid_w, self.embed_dim)
        return x

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)] for the conv weight and bias.

        fan_in is in_chans * patch_h * patch_w. The previous formula squared
        patch_size[0], which is only correct for square patches. The result is
        identical for square patches.
        """
        fan_in = self.in_chans * self.patch_size[0] * self.patch_size[1]
        bound = math.sqrt(1 / fan_in)
        nn.init.uniform_(self.proj.weight, -bound, bound)
        if self.proj.bias is not None:
            nn.init.uniform_(self.proj.bias, -bound, bound)
# ---------- RoPE (from eupe/layers/rope_position_encoding.py) --------------
class RopePositionEmbedding(nn.Module):
    """2-D rotary position embedding (RoPE) tables for a grid of patches.

    ``forward(H=..., W=...)`` returns a ``(sin, cos)`` pair, each of shape
    (H * W, D_head) with D_head = embed_dim // num_heads. Patch-center
    coordinates are normalized, mapped by c -> 2c - 1, and divided by
    D_head // 4 periods for each of the two axes. That gives D_head // 2 angles,
    which are tiled twice to fill D_head.

    Exactly one of ``base`` or the ``min_period``/``max_period`` pair selects
    how the periods are spaced (ValueError otherwise). ``shift_coords``,
    ``jitter_coords`` and ``rescale_coords`` enable random coordinate
    augmentations that are applied only in training mode.
    """

    def __init__(
        self,
        embed_dim: int,
        *,
        num_heads: int,
        base: Optional[float] = 100.0,
        min_period: Optional[float] = None,
        max_period: Optional[float] = None,
        normalize_coords: Literal["min", "max", "separate"] = "separate",
        shift_coords: Optional[float] = None,
        jitter_coords: Optional[float] = None,
        rescale_coords: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        # D_head must split into 4 equal parts: D_head // 4 periods, applied to
        # two axes (-> D_head // 2 angles), then tiled twice in forward().
        assert embed_dim % (4 * num_heads) == 0
        both_periods = min_period is not None and max_period is not None
        # Exactly one period specification must be supplied.
        if (base is None and not both_periods) or (base is not None and both_periods):
            raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
        D_head = embed_dim // num_heads
        self.base = base
        self.min_period = min_period
        self.max_period = max_period
        self.D_head = D_head
        self.normalize_coords = normalize_coords
        self.shift_coords = shift_coords
        self.jitter_coords = jitter_coords
        self.rescale_coords = rescale_coords
        self.dtype = dtype
        # Persistent buffer: saved to and restored from the state_dict.
        self.register_buffer(
            "periods",
            torch.empty(D_head // 4, device=device, dtype=dtype),
            persistent=True,
        )
        self._init_weights()

    def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]:
        """Return ``(sin, cos)``, each (H * W, D_head), for an H x W patch grid."""
        device = self.periods.device
        dtype = self.dtype
        dd = {"device": device, "dtype": dtype}
        # Patch-center coordinates (0.5, 1.5, ...) divided by the chosen normalizer.
        if self.normalize_coords == "max":
            max_HW = max(H, W)
            coords_h = torch.arange(0.5, H, **dd) / max_HW
            coords_w = torch.arange(0.5, W, **dd) / max_HW
        elif self.normalize_coords == "min":
            min_HW = min(H, W)
            coords_h = torch.arange(0.5, H, **dd) / min_HW
            coords_w = torch.arange(0.5, W, **dd) / min_HW
        elif self.normalize_coords == "separate":
            coords_h = torch.arange(0.5, H, **dd) / H
            coords_w = torch.arange(0.5, W, **dd) / W
        else:
            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
        coords = coords.flatten(0, 1)  # (H * W, 2), row-major over the grid
        coords = 2.0 * coords - 1.0  # affine map c -> 2c - 1
        # Training-only augmentations (each optional):
        # random per-axis shift drawn uniformly from [-shift, shift].
        if self.training and self.shift_coords is not None:
            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
            coords += shift_hw[None, :]
        # Per-axis multiplicative jitter, log-uniform in [1/jitter, jitter].
        if self.training and self.jitter_coords is not None:
            jitter_max = np.log(self.jitter_coords)
            jitter_min = -jitter_max
            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
            coords *= jitter_hw[None, :]
        # A single multiplicative rescale shared by both axes, log-uniform in [1/r, r].
        if self.training and self.rescale_coords is not None:
            rescale_max = np.log(self.rescale_coords)
            rescale_min = -rescale_max
            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
            coords *= rescale_hw
        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]  # (HW, 2, D_head // 4)
        angles = angles.flatten(1, 2)  # (HW, D_head // 2)
        angles = angles.tile(2)  # (HW, D_head)
        cos = torch.cos(angles)
        sin = torch.sin(angles)
        return (sin, cos)

    def _init_weights(self):
        """(Re)compute the ``periods`` buffer from ``base`` or the min/max period range."""
        device = self.periods.device
        dtype = self.dtype
        if self.base is not None:
            # periods[i] = base ** (2 * i / (D_head // 2)) for i in [0, D_head // 4)
            periods = self.base ** (
                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
            )
        else:
            # Geometric spacing from min_period (exponent 0) up to max_period (exponent 1).
            base = self.max_period / self.min_period
            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
            periods = base ** exponents
            periods = periods / base
            periods = periods * self.max_period
        self.periods.data = periods
# ---------- FFN layers (from eupe/layers/ffn_layers.py) --------------------
class ListForwardMixin(object):
    """Adds ``forward_list``: apply ``forward`` once to several differently-shaped inputs.

    The inputs are flattened to rows and concatenated, pushed through
    ``forward`` in a single call, then split back to their original leading
    shapes. Subclasses must implement ``forward``.
    """

    def forward(self, x: Tensor):
        raise NotImplementedError

    def forward_list(self, x_list: List[Tensor]) -> List[Tensor]:
        packed, shapes, counts = cat_keep_shapes(x_list)
        return uncat_with_shapes(self.forward(packed), shapes, counts)
class Mlp(nn.Module, ListForwardMixin):
    """Two-layer feed-forward net: fc1 -> act -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` default to ``in_features``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
        device=None,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        # Submodules are created in fc1, act, fc2, drop order (same RNG draw order).
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias, device=device)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias, device=device)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class SwiGLUFFN(nn.Module, ListForwardMixin):
    """SwiGLU feed-forward: w3(silu(w1(x)) * w2(x)).

    The gated hidden width is int(hidden_features * 2 / 3), rounded up to a
    multiple of ``align_to``. ``act_layer`` and ``drop`` are accepted for
    signature compatibility with ``Mlp`` but are not used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
        align_to: int = 8,
        device=None,
    ) -> None:
        super().__init__()
        output_dim = out_features or in_features
        requested = hidden_features or in_features
        base_width = int(requested * 2 / 3)
        gated_width = base_width + (-base_width % align_to)
        self.w1 = nn.Linear(in_features, gated_width, bias=bias, device=device)
        self.w2 = nn.Linear(in_features, gated_width, bias=bias, device=device)
        self.w3 = nn.Linear(gated_width, output_dim, bias=bias, device=device)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
# ---------- Attention (from eupe/layers/attention.py) ----------------------
def rope_rotate_half(x: Tensor) -> Tensor:
    """Swap the two halves of the last dim with a sign flip: (a, b) -> (-b, a)."""
    first, second = x.chunk(2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    """Rotary embedding: x * cos + rotate_half(x) * sin, with broadcasting."""
    rotated = rope_rotate_half(x)
    return (x * cos) + (rotated * sin)
class LinearKMaskedBias(nn.Linear):
    """Linear layer whose bias is multiplied elementwise by a ``bias_mask`` buffer.

    Meant for a fused qkv projection (``out_features`` must be divisible by 3).
    The mask buffer starts filled with NaN, so it must be set before use
    (``init_weights_vit`` writes 1s with the middle third zeroed) or loaded
    from a checkpoint; otherwise every output is NaN. Without a bias the layer
    behaves like a plain ``F.linear``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.out_features % 3 == 0
        if self.bias is None:
            return
        self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan))

    def forward(self, input: Tensor) -> Tensor:
        if self.bias is None:
            return F.linear(input, self.weight, None)
        masked_bias = self.bias * self.bias_mask.to(self.bias.dtype)
        return F.linear(input, self.weight, masked_bias)
class SelfAttention(nn.Module):
    """Multi-head self-attention over (B, N, dim) tokens with optional RoPE.

    A fused ``qkv`` projection is split into per-head q/k/v of shape
    (B, heads, N, head_dim). When ``rope=(sin, cos)`` is given, it rotates the
    last ``sin.shape[-2]`` tokens of q and k; the leading tokens (e.g. class
    and storage tokens) are left unrotated. ``attn_drop`` is applied to the
    attention weights in training mode, and ``proj_drop`` after the output
    projection in both ``forward`` and ``forward_list``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Same value SDPA uses by default; kept as a public attribute.
        self.scale = head_dim ** -0.5
        linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear
        self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]:
        """Rotate the trailing tokens of q and k (B, heads, N, head_dim) with ``rope=(sin, cos)``.

        The rotation is computed in the sin/cos dtype and cast back to the
        original q/k dtypes.
        """
        q_dtype = q.dtype
        k_dtype = k.dtype
        sin, cos = rope
        rope_dtype = sin.dtype
        q = q.to(dtype=rope_dtype)
        k = k.to(dtype=rope_dtype)
        N = q.shape[-2]
        # Number of leading tokens that get no positional rotation.
        prefix = N - sin.shape[-2]
        assert prefix >= 0
        q_prefix = q[:, :, :prefix, :]
        q = rope_apply(q[:, :, prefix:, :], sin, cos)
        q = torch.cat((q_prefix, q), dim=-2)
        k_prefix = k[:, :, :prefix, :]
        k = rope_apply(k[:, :, prefix:, :], sin, cos)
        k = torch.cat((k_prefix, k), dim=-2)
        q = q.to(dtype=q_dtype)
        k = k.to(dtype=k_dtype)
        return q, k

    def forward(self, x: Tensor, attn_bias=None, rope=None) -> Tensor:
        qkv = self.qkv(x)
        attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope)
        x = self.proj(attn_v)
        x = self.proj_drop(x)
        return x

    def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]:
        """Attention for several token tensors; the linear layers run once on the packed rows.

        ``rope_list`` must hold one rope entry (or None) per tensor in ``x_list``.
        """
        assert len(x_list) == len(rope_list)
        x_flat, shapes, num_tokens = cat_keep_shapes(x_list)
        qkv_list = uncat_with_shapes(self.qkv(x_flat), shapes, num_tokens)
        att_out = [
            self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)
            for qkv, rope in zip(qkv_list, rope_list)
        ]
        x_flat, shapes, num_tokens = cat_keep_shapes(att_out)
        # proj_drop was previously skipped here; apply it to match forward().
        x_flat = self.proj_drop(self.proj(x_flat))
        return uncat_with_shapes(x_flat, shapes, num_tokens)

    def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor:
        """Scaled dot-product attention on a fused (B, N, 3 * dim) qkv tensor; returns (B, N, dim)."""
        assert attn_bias is None
        B, N, _ = qkv.shape
        C = self.qkv.in_features
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = torch.unbind(qkv, 2)
        q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
        if rope is not None:
            q, k = self.apply_rope(q, k, rope)
        # The attn_drop module was previously constructed but never applied.
        # SDPA applies dropout_p unconditionally, so gate it on training mode.
        dropout_p = self.attn_drop.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        x = x.transpose(1, 2)
        return x.reshape([B, N, C])
# ---------- Block (from eupe/layers/block.py) ------------------------------
class SelfAttentionBlock(nn.Module):
    """Pre-norm transformer block.

    Computes x = x + ls1(attn(norm1(x), rope)), then x = x + ls2(mlp(norm2(x))).

    ``forward`` accepts either one tensor (with one rope entry) or a list of
    tensors (with a list of rope entries). In training mode with
    ``drop_path > 0``, stochastic depth is implemented by running each residual
    branch on a random subset of the batch. The branch output is added back
    with ``index_add`` and scaled by batch_size / subset_size.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = SelfAttention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        mask_k_bias: bool = False,
        device=None,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            mask_k_bias=mask_k_bias,
            device=device,
        )
        # LayerScale only when init_values is truthy; otherwise a no-op Identity.
        self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * ffn_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
            device=device,
        )
        self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
        # Stochastic-depth ratio: each residual branch sees only
        # max(int(batch * (1 - ratio)), 1) samples during training.
        self.sample_drop_ratio = drop_path

    @staticmethod
    def _maybe_index_rope(rope, indices: Tensor):
        """Select the batch subset ``indices`` from per-sample (4-D) rope tables.

        Non-batched tables are returned unchanged, and None stays None.
        """
        if rope is None:
            return None
        sin, cos = rope
        assert sin.ndim == cos.ndim
        if sin.ndim == 4:
            return sin[indices], cos[indices]
        return sin, cos

    def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]:
        """Apply the block to each tensor in ``x_list``; returns a list of the same length."""
        b_list = [x.shape[0] for x in x_list]
        sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
        if self.training and self.sample_drop_ratio > 0.0:
            # Stochastic-depth path: each branch runs on a random batch subset
            # and is added back scaled by b / subset_size.
            residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)]
            indices_1_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)]
            if rope_list is not None:
                rope_subset_list = [
                    self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list)
                ]
            else:
                rope_subset_list = rope_list
            # norm1 runs once over all subsets packed together.
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list)
            norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens)
            residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list)
            # index_add returns a new tensor: x with alpha * branch added at rows i1.
            x_attn_list = [
                torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf)
                for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors)
            ]
            # A fresh random subset is drawn for the MLP branch.
            indices_2_list = [
                torch.randperm(b, device=x.device)[:s]
                for x, b, s in zip(x_list, b_list, sample_subset_sizes)
            ]
            x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)]
            flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list)
            norm2_list = uncat_with_shapes(self.norm2(flattened), shapes, num_tokens)
            residual_2_list = self.mlp.forward_list(norm2_list)
            x_ffn = [
                torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf)
                for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors)
            ]
        else:
            # Plain residual path (eval mode or no drop path). rope_list must be a
            # list here; forward() always supplies one.
            x_out = []
            for x, rope in zip(x_list, rope_list):
                x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope))
                x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn)))
                x_out.append(x_ffn)
            x_ffn = x_out
        return x_ffn

    def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]:
        """Tensor in -> tensor out; list in -> list out. Other input types raise AssertionError."""
        if isinstance(x_or_x_list, Tensor):
            return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0]
        elif isinstance(x_or_x_list, list):
            if rope_or_rope_list is None:
                rope_or_rope_list = [None for _ in x_or_x_list]
            return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list)
        raise AssertionError
# ---------- DinoVisionTransformer (from eupe/models/vision_transformer.py)
# Config-string registries used by DinoVisionTransformer.
# FFN variants: plain MLP, and SwiGLU with different hidden-width alignments.
ffn_layer_dict = {
    "mlp": Mlp,
    "swiglu": SwiGLUFFN,
    **{
        key: partial(SwiGLUFFN, align_to=alignment)
        for key, alignment in (("swiglu32", 32), ("swiglu64", 64), ("swiglu128", 128))
    },
}
# Normalization layers; the two LayerNorm entries differ only in eps.
norm_layer_dict = dict(
    layernorm=partial(nn.LayerNorm, eps=1e-6),
    layernormbf16=partial(nn.LayerNorm, eps=1e-5),
    rmsnorm=RMSNorm,
)
# Short dtype names accepted by config (e.g. pos_embed_rope_dtype).
dtype_dict = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)
def init_weights_vit(module: nn.Module, name: str = ""):
    """Initialize one module in place; intended to be driven by ``named_apply``.

    For ``nn.Linear``, the weight gets a truncated normal (std 0.02) and the
    bias is zeroed. If the layer also has a ``bias_mask`` buffer, the mask is
    set to 1 except its middle third (the k slice of a fused qkv), which is
    set to 0. ``nn.LayerNorm``, ``LayerScale``, ``PatchEmbed`` and ``RMSNorm``
    call their own ``reset_parameters``. ``name`` is part of the
    ``named_apply`` calling convention and is unused.
    """
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if getattr(module, "bias_mask", None) is not None:
            out_dim = module.out_features
            module.bias_mask.fill_(1)
            module.bias_mask[out_dim // 3 : 2 * out_dim // 3].fill_(0)
    for resettable in (nn.LayerNorm, LayerScale, PatchEmbed, RMSNorm):
        if isinstance(module, resettable):
            module.reset_parameters()
class DinoVisionTransformer(nn.Module):
    """ViT backbone with RoPE, a CLS token and optional storage tokens.

    After ``prepare_tokens_with_masks`` the token sequence is
    ``[CLS, storage_0 .. storage_{n-1}, patch_0 .. patch_{HW-1}]``.
    ``forward_features`` returns one dict per input with these keys:

    - ``x_norm_clstoken``: the normalized CLS token.
    - ``x_storage_tokens``: the normalized storage tokens.
    - ``x_norm_patchtokens``: the normalized patch tokens.
    - ``x_prenorm``: the tokens before the final norm.
    - ``masks``: the masks that were passed in.

    The cls/storage/mask token parameters are allocated with ``torch.empty``.
    Call ``init_weights`` or load a checkpoint before use.
    """

    def __init__(
        self,
        *,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: Optional[float] = None,
        pos_embed_rope_max_period: Optional[float] = None,
        pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate",
        pos_embed_rope_shift_coords: Optional[float] = None,
        pos_embed_rope_jitter_coords: Optional[float] = None,
        pos_embed_rope_rescale_coords: Optional[float] = None,
        pos_embed_rope_dtype: str = "bf16",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        layerscale_init: Optional[float] = None,
        norm_layer: str = "layernorm",
        ffn_layer: str = "mlp",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        n_storage_tokens: int = 0,
        mask_k_bias: bool = False,
        untie_cls_and_patch_norms: bool = False,
        untie_global_and_local_cls_norm: bool = False,
        device: Any = None,
        **ignored_kwargs,
    ):
        super().__init__()
        # Unknown config keys are accepted and discarded.
        del ignored_kwargs
        norm_layer_cls = norm_layer_dict[norm_layer]
        self.num_features = self.embed_dim = embed_dim
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        # flatten_embedding=False: patch_embed returns (B, h, w, D), so the
        # grid size is available for the RoPE tables.
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            flatten_embedding=False,
        )
        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            normalize_coords=pos_embed_rope_normalize_coords,
            shift_coords=pos_embed_rope_shift_coords,
            jitter_coords=pos_embed_rope_jitter_coords,
            rescale_coords=pos_embed_rope_rescale_coords,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )
        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=drop_path_rate,
                norm_layer=norm_layer_cls,
                act_layer=nn.GELU,
                ffn_layer=ffn_layer_cls,
                init_values=layerscale_init,
                mask_k_bias=mask_k_bias,
                device=device,
            )
            for i in range(depth)
        ]
        self.chunked_blocks = False
        self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer_cls(embed_dim)
        self.untie_cls_and_patch_norms = untie_cls_and_patch_norms
        self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None
        self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm
        self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None
        self.head = nn.Identity()
        self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device))

    def init_weights(self):
        """Initialize the RoPE periods, the cls/storage/mask tokens, and every submodule."""
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        nn.init.zeros_(self.mask_token)
        named_apply(init_weights_vit, self)

    def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]:
        """Embed patches, optionally replace masked ones, and prepend cls/storage tokens.

        Returns the token tensor and the patch-grid size (H, W).
        ``masks`` is presumably a (B, H * W) boolean tensor (it is broadcast
        against the flattened patch tokens) — TODO confirm against callers.
        """
        x = self.patch_embed(x)
        B, H, W, _ = x.shape
        x = x.flatten(1, 2)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
            cls_token = self.cls_token
        else:
            # 0 * mask_token keeps mask_token in the graph even when unused.
            cls_token = self.cls_token + 0 * self.mask_token
        if self.n_storage_tokens > 0:
            storage_tokens = self.storage_tokens
        else:
            storage_tokens = torch.empty(
                1, 0, cls_token.shape[-1],
                dtype=cls_token.dtype, device=cls_token.device,
            )
        x = torch.cat(
            [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x],
            dim=1,
        )
        return x, (H, W)

    def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]:
        """Run several image batches (possibly of different sizes) through the backbone together.

        ``masks_list`` must contain one entry (a mask or None) per element of
        ``x_list``. Returns one output dict per input.
        """
        x = []
        rope = []
        for t_x, t_masks in zip(x_list, masks_list):
            t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks)
            x.append(t2_x)
            rope.append(hw_tuple)
        for blk in self.blocks:
            # RoPE tables are rebuilt for every block. In training mode
            # rope_embed draws fresh random coordinate augmentation on each
            # call, so hoisting this out of the loop would change training.
            if self.rope_embed is not None:
                rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope]
            else:
                rope_sincos = [None for _ in rope]
            x = blk(x, rope_sincos)
        all_x = x
        output = []
        for idx, (x, masks) in enumerate(zip(all_x, masks_list)):
            if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm:
                if self.untie_global_and_local_cls_norm and self.training and idx == 1:
                    x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1])
                elif self.untie_cls_and_patch_norms:
                    x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1])
                else:
                    x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1])
                x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :])
            else:
                x_norm = self.norm(x)
                x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
                x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
            output.append({
                "x_norm_clstoken": x_norm_cls_reg[:, 0],
                "x_storage_tokens": x_norm_cls_reg[:, 1:],
                "x_norm_patchtokens": x_norm_patch,
                "x_prenorm": x,
                "masks": masks,
            })
        return output

    def forward_features(self, x, masks: Optional[Tensor] = None):
        """Tensor input -> one output dict; list input -> list of dicts.

        For a list input, ``masks`` is a list with one mask (or None) per
        image. ``None`` means "no masks for any image". It previously crashed in
        ``zip(x, None)``.
        """
        if isinstance(x, torch.Tensor):
            return self.forward_features_list([x], [masks])[0]
        masks_list = [None] * len(x) if masks is None else masks
        return self.forward_features_list(x, masks_list)

    def forward(self, *args, is_training: bool = False, **kwargs):
        """Return ``head(cls_token)``, or the full feature dict(s) when ``is_training``."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        return self.head(ret["x_norm_clstoken"])
def build_eupe_vitb16() -> DinoVisionTransformer:
    """Construct the EUPE ViT-B/16 backbone in the configuration Argus uses.

    The notable choice is qkv_bias=False with mask_k_bias=False. According to
    the original authors' note, the upstream EUPE-ViT-B release shipped with
    ``qkv.bias_mask`` filled with zeros. That makes the effective qkv bias zero
    in every block (masked_bias = bias * 0 = 0), so the bias parameter is
    dropped entirely. They report the computation as bitwise-equivalent in
    fp32. In bf16 the drift is sub-ULP and is absorbed by every head except DPT
    depth, where it appears as ~2cm noise against a 39cm RMSE, below that
    head's own floor.
    """
    config = dict(
        # geometry
        img_size=224,
        patch_size=16,
        in_chans=3,
        # rotary position embedding
        pos_embed_rope_base=100,
        pos_embed_rope_normalize_coords="separate",
        pos_embed_rope_rescale_coords=2,
        pos_embed_rope_dtype="fp32",
        # transformer trunk
        embed_dim=768,
        depth=12,
        num_heads=12,
        ffn_ratio=4,
        qkv_bias=False,
        drop_path_rate=0.0,
        layerscale_init=1.0e-05,
        norm_layer="layernormbf16",
        ffn_layer="mlp",
        ffn_bias=True,
        proj_bias=True,
        n_storage_tokens=4,
        mask_k_bias=False,
    )
    return DinoVisionTransformer(**config)
# ===========================================================================
# Argus task heads
# ===========================================================================
def make_eupe_transform(resize_size: int):
    """Build the EUPE preprocessing pipeline for a square ``resize_size`` input.

    Steps: convert to an image tensor, resize to (resize_size, resize_size)
    with antialiasing (aspect ratio is not preserved), convert to float32 with
    value scaling, then normalize with the ImageNet mean/std.
    """
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        v2.ToImage(),
        v2.Resize((resize_size, resize_size), antialias=True),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return v2.Compose(steps)
def _normalize_image_input(image_or_images) -> Tuple[bool, list]:
    """Return ``(was_single, images)`` for one PIL image or an iterable of PIL images.

    Raises:
        ValueError: the iterable is empty.
        TypeError: some element is not a ``PIL.Image.Image``. The first
            offending index is reported.
    """
    if isinstance(image_or_images, Image.Image):
        return True, [image_or_images]
    images = list(image_or_images)
    if len(images) == 0:
        raise ValueError("empty image list")
    offender = next(
        ((i, img) for i, img in enumerate(images) if not isinstance(img, Image.Image)),
        None,
    )
    if offender is not None:
        i, img = offender
        raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image")
    return False, images
class _BackboneExportWrapper(nn.Module):
"""ONNX-friendly wrapper: returns (cls, spatial) instead of a dict."""
def __init__(self, backbone: nn.Module):
super().__init__()
self.backbone = backbone
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
out = self.backbone.forward_features(x)
cls = out["x_norm_clstoken"]
patches = out["x_norm_patchtokens"]
B, N, D = patches.shape
h = w = int(N ** 0.5)
spatial = patches.permute(0, 2, 1).reshape(B, D, h, w)
return cls, spatial
class _SegHeadExportWrapper(nn.Module):
"""ONNX-friendly wrapper: seg head + bilinear upsample to input resolution.
The bare seg head emits stride-16 logits (e.g. [B, 150, 40, 40] at 640px
input). model.segment() upsamples those to the input resolution before
argmax. This wrapper folds the upsample into the graph so the ONNX seg
output is already at input resolution — consumers argmax directly without
a separate interpolation step.
"""
def __init__(self, seg_head: nn.Module, resolution: int):
super().__init__()
self.seg_head = seg_head
self.resolution = resolution
def forward(self, spatial_features: Tensor) -> Tensor:
logits = self.seg_head(spatial_features)
return F.interpolate(logits, size=(self.resolution, self.resolution),
mode="bilinear", align_corners=False)
class _DepthHeadExportWrapper(nn.Module):
"""ONNX-friendly wrapper for the DPT depth head.
DPTDepthDecoder.forward takes (intermediates: List[Tensor], H: int, W: int),
which torch.onnx.export cannot trace cleanly because the List contains four
tensors and H/W are Python ints. The wrapper accepts the four intermediate
ViT-block activations as separate positional tensor inputs and forwards them
to the underlying decoder with the captured H and W.
"""
def __init__(self, depth_head: nn.Module, H: int, W: int):
super().__init__()
self.depth_head = depth_head
self.H = H
self.W = W
def forward(self, inter0: Tensor, inter1: Tensor, inter2: Tensor, inter3: Tensor) -> Tensor:
return self.depth_head([inter0, inter1, inter2, inter3], self.H, self.W)
class _ClassifierExportWrapper(nn.Module):
"""ONNX-friendly wrapper for the ImageNet linear-softmax classifier.
Takes the backbone's CLS token, L2-normalizes, applies the stored
Linear(embed_dim, 1000) weight + bias, and returns a softmax
distribution over the 1000 ImageNet classes. The weight and bias are
captured as buffers so the graph is self-contained — no separate
weight file needed for classification inference.
"""
def __init__(self, class_weight: Tensor, class_bias: Tensor):
super().__init__()
self.register_buffer("weight", class_weight.float().clone())
self.register_buffer("bias", class_bias.float().clone())
def forward(self, cls_token: Tensor) -> Tensor:
x = F.normalize(cls_token, dim=-1)
logits = F.linear(x, self.weight, self.bias)
return F.softmax(logits, dim=-1)
class _ONNXBatchedNMS(torch.autograd.Function):
    """Autograd wrapper that exports to ONNX NonMaxSuppression (opset >= 10).

    ONNX's NonMaxSuppression handles batched multi-class NMS natively:
        boxes  [B, N, 4] in [y1, x1, y2, x2] order (center_point_box=0)
        scores [B, C, N]
        -> selected_indices [M, 3] where each row is [batch, class, box]

    The eager forward path reproduces this with torchvision.ops.nms, so
    PyTorch tracing and verify=True both work without calling into ORT for
    the reference. For each (batch, class) it keeps boxes with
    score > score_threshold, runs NMS, and keeps at most
    max_output_boxes_per_class survivors. Rows are emitted in batch, then
    class, then descending-score order.
    """

    @staticmethod
    def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        return g.op(
            "NonMaxSuppression",
            boxes, scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            center_point_box_i=0,
        )

    @staticmethod
    def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold):
        from torchvision.ops import nms as tv_nms
        B, _, _ = boxes.shape
        _, C, _ = scores.shape
        max_out = int(max_output_boxes_per_class.item())
        iou_thr = float(iou_threshold.item())
        score_thr = float(score_threshold.item())
        selected: List[Tensor] = []
        for b in range(B):
            # tv_nms expects [x1, y1, x2, y2]; our boxes are [y1, x1, y2, x2].
            # The reorder is the same for every class, so do it once per batch.
            boxes_xyxy = boxes[b][:, [1, 0, 3, 2]]
            for c in range(C):
                sc = scores[b, c]
                idx = (sc > score_thr).nonzero(as_tuple=True)[0]
                if idx.numel() == 0:
                    continue
                keep = tv_nms(boxes_xyxy[idx], sc[idx], iou_thr)[:max_out]
                # Map survivors back to original box indices with tensor ops.
                # The previous version did a per-element .item() (a device sync each).
                box_idx = idx[keep]
                selected.append(torch.stack(
                    [torch.full_like(box_idx, b), torch.full_like(box_idx, c), box_idx],
                    dim=1,
                ))
        if not selected:
            return torch.zeros((0, 3), dtype=torch.long, device=boxes.device)
        return torch.cat(selected, dim=0)
class _DetectionHeadExportWrapper(nn.Module):
    """ONNX-friendly wrapper for the detection head.

    Takes backbone stride-16 spatial features [B, D, H, W] and returns
    decoded per-location predictions concatenated across all prediction
    levels (P3 plus the head's cofiber bands; five levels for the default
    ``n_scales=4``).

    Without NMS (default):
      - boxes  [B, N_total, 4] xyxy in input-resolution pixels,
               decoded as (location - reg) / (location + reg), where
               ``reg`` are the head's already-exponentiated LTRB
               distances, and clamped to [0, resolution].
      - scores [B, N_total, num_classes]
               sigmoid(cls_logits) * sigmoid(centerness).
    With NMS (include_nms=True):
      - boxes        [M, 4] xyxy in input-resolution pixels
      - scores       [M]
      - class_labels [M] int64 class index
      - batch_indices[M] int64 batch index

    N_total = sum(H_i * W_i) across strides [8, 16, 32, 64, 128]. At
    640px input: 6400 + 1600 + 400 + 100 + 25 = 8525 locations/image.

    The NMS variant folds ONNX's NonMaxSuppression (opset >= 10) into
    the graph using the configured iou / score / max_detections
    parameters, producing a flat list of surviving detections across
    all batches and classes. Useful for single-shot TensorRT / mobile
    inference. Without NMS the consumer runs their own — hard vs soft,
    per-class vs global, threshold tuning — without re-exporting.
    """
    def __init__(self, detection_head: nn.Module, resolution: int,
                 include_nms: bool = False,
                 nms_iou_threshold: float = 0.5,
                 nms_score_threshold: float = 0.05,
                 nms_max_detections: int = 100):
        super().__init__()
        self.detection_head = detection_head
        self.resolution = resolution
        self.num_classes = detection_head.num_classes
        self.include_nms = include_nms
        self.nms_iou_threshold = nms_iou_threshold
        self.nms_score_threshold = nms_score_threshold
        self.nms_max_detections = nms_max_detections
        # Level sizes must equal the head's actual output shapes, otherwise
        # `all_locs` and the flattened predictions disagree in length. Box
        # centers still use the nominal FPN_STRIDES because that is what the
        # eager `model.detect` path does.
        H = resolution // 16
        sides = self._level_sides(detection_head, H)
        if len(sides) > len(FPN_STRIDES):
            raise ValueError(
                f"detection head produces {len(sides)} levels but only "
                f"{len(FPN_STRIDES)} strides are defined in FPN_STRIDES")
        locs_per_level = []
        for side, s in zip(sides, FPN_STRIDES):
            ys = (torch.arange(side, dtype=torch.float32) + 0.5) * s
            xs = (torch.arange(side, dtype=torch.float32) + 0.5) * s
            gy, gx = torch.meshgrid(ys, xs, indexing="ij")
            locs_per_level.append(torch.stack([gx.flatten(), gy.flatten()], -1))
        all_locs = torch.cat(locs_per_level, 0)
        self.register_buffer("all_locs", all_locs)

    @staticmethod
    def _level_sides(detection_head: nn.Module, H: int) -> List[int]:
        """Side length of every (square) prediction level, finest (P3) first."""
        n_scales = getattr(detection_head, "n_scales", None)
        if n_scales is not None:
            # Cofiber head (SplitTowerHead): P3 = 2H via ConvTranspose2d(k=2,
            # s=2); band 0 = H; each further band is F.avg_pool2d(x, 2) of the
            # previous one, which FLOORS odd sizes (e.g. 50 -> 25 -> 12 -> 6).
            sides = [2 * H, H]
            for _ in range(n_scales - 1):
                sides.append(sides[-1] // 2)
            return sides
        # Legacy simple feature pyramid: stride-2 3x3 convs with padding=1
        # round odd sizes UP.
        p5 = (H + 1) // 2
        p6 = (p5 + 1) // 2
        p7 = (p6 + 1) // 2
        return [2 * H, H, p5, p6, p7]

    def forward(self, spatial_features: Tensor):
        cls_logits, box_regs, centernesses = self.detection_head(spatial_features)
        B = spatial_features.shape[0]
        flat_cls = torch.cat(
            [c.permute(0, 2, 3, 1).reshape(B, -1, self.num_classes) for c in cls_logits], dim=1)
        flat_reg = torch.cat(
            [r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in box_regs], dim=1)
        flat_ctr = torch.cat(
            [c.permute(0, 2, 3, 1).reshape(B, -1, 1) for c in centernesses], dim=1)
        scores = torch.sigmoid(flat_cls) * torch.sigmoid(flat_ctr)
        locs = self.all_locs.unsqueeze(0).expand(B, -1, -1)
        x1 = (locs[..., 0:1] - flat_reg[..., 0:1]).clamp(0, self.resolution)
        y1 = (locs[..., 1:2] - flat_reg[..., 1:2]).clamp(0, self.resolution)
        x2 = (locs[..., 0:1] + flat_reg[..., 2:3]).clamp(0, self.resolution)
        y2 = (locs[..., 1:2] + flat_reg[..., 3:4]).clamp(0, self.resolution)
        boxes = torch.cat([x1, y1, x2, y2], dim=-1)
        if not self.include_nms:
            return boxes, scores
        # ONNX NMS expects boxes in [y1, x1, y2, x2] (center_point_box=0) and
        # scores with the class dim in the middle: [B, C, N].
        boxes_yxyx = torch.cat([y1, x1, y2, x2], dim=-1)
        scores_bcn = scores.permute(0, 2, 1).contiguous()
        max_out = torch.tensor(self.nms_max_detections, dtype=torch.long, device=boxes.device)
        iou_thr = torch.tensor(self.nms_iou_threshold, dtype=torch.float32, device=boxes.device)
        score_thr = torch.tensor(self.nms_score_threshold, dtype=torch.float32, device=boxes.device)
        selected = _ONNXBatchedNMS.apply(
            boxes_yxyx, scores_bcn, max_out, iou_thr, score_thr,
        )
        batch_idx = selected[:, 0].long()
        class_idx = selected[:, 1].long()
        box_idx = selected[:, 2].long()
        sel_boxes = boxes[batch_idx, box_idx]  # [M, 4] xyxy
        sel_scores = scores[batch_idx, box_idx, class_idx]  # [M]
        return sel_boxes, sel_scores, class_idx, batch_idx
class SegmentationHead(nn.Module):
    """Linear semantic-segmentation probe on the stride-16 patch grid.

    BatchNorm over the backbone channels followed by a 1x1 convolution that
    maps each location to ``num_classes`` logits; spatial size is unchanged.
    """
    def __init__(self, in_dim: int = 768, num_classes: int = 150):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1)

    def forward(self, x: Tensor) -> Tensor:
        normalized = self.batchnorm_layer(x)
        class_logits = self.conv(normalized)
        return class_logits
class DepthHead(nn.Module):
    """Linear depth probe using bin classification.

    BatchNorm + 1x1 conv produce ``n_bins`` logits per location. They are
    turned into positive weights (ReLU + 0.1), normalized over the bin axis,
    and used to take the expectation over ``n_bins`` evenly spaced depth
    values in ``[min_depth, max_depth]``. Returns a [B, 1, H, W] map.
    """
    def __init__(self, in_dim: int = 768, n_bins: int = 256,
                 min_depth: float = 0.001, max_depth: float = 10.0):
        super().__init__()
        self.batchnorm_layer = nn.BatchNorm2d(in_dim)
        self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1)
        self.min_depth, self.max_depth, self.n_bins = min_depth, max_depth, n_bins

    def forward(self, x: Tensor) -> Tensor:
        bin_logits = self.conv_depth(self.batchnorm_layer(x))
        # The +0.1 floor keeps every bin weight strictly positive.
        weights = torch.relu(bin_logits) + 0.1
        weights = weights / weights.sum(dim=1, keepdim=True)
        centers = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device)
        expected = torch.einsum("bkhw,k->bhw", weights, centers)
        return expected[:, None]
# ===========================================================================
# Detection (cofiber split-tower head; LTRB + centerness box decoding)
# ===========================================================================
# Nominal stride, in input pixels, of each detection level in the order the
# head emits them: P3 (transposed-conv upsample of the stride-16 band), then
# the stride-16/32/64/128 cofiber bands. Used to place per-location centers
# when decoding boxes.
FPN_STRIDES = [8, 16, 32, 64, 128]
# The 80 COCO category names; default for ArgusConfig.detection_class_names.
# NOTE(review): presumably indexed by the detection head's class label (row
# order of its text_embed buffer) — confirm against the checkpoint.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]
def cofiber_decompose(f: Tensor, n_scales: int) -> List[Tensor]:
    """Split ``f`` into ``n_scales`` frequency bands with no learned weights.

    At each step the current signal is 2x average-pooled. The detail band
    (current minus the bilinearly re-upsampled pooled signal) is emitted,
    and the pooled signal becomes the next input. After ``n_scales - 1``
    steps the remaining low-frequency signal is appended as the last band.
    Band ``i`` therefore has the spatial size of the input halved ``i``
    times (floored by ``avg_pool2d``).
    """
    bands: List[Tensor] = []
    current = f
    remaining = n_scales - 1
    while remaining > 0:
        coarse = F.avg_pool2d(current, 2)
        smooth = F.interpolate(coarse, size=current.shape[2:],
                               mode="bilinear", align_corners=False)
        bands.append(current - smooth)
        current = coarse
        remaining -= 1
    bands.append(current)
    return bands
def make_sin_pos_emb(H: int, W: int, dim: int, device) -> Tensor:
    """Return a [1, dim, H, W] 2D sinusoidal positional encoding.

    The first ``dim // 2`` channels encode the row index and vary only
    along H. The last ``dim // 2`` encode the column index and vary only
    along W. Within each half, channels alternate sin/cos over ``dim // 4``
    geometrically spaced frequencies. The result is concatenated to the
    backbone patch features before the head stem.
    """
    assert dim % 4 == 0, "pos emb dim must be divisible by 4"
    quarter = dim // 4
    freqs = torch.exp(torch.arange(quarter, device=device, dtype=torch.float32)
                      * -(math.log(10000.0) / quarter))

    def _encode(length: int) -> Tensor:
        # [length, 2 * quarter] laid out as [sin f0, cos f0, sin f1, cos f1, ...]
        coords = torch.arange(length, device=device, dtype=torch.float32)
        angles = coords[:, None] * freqs[None, :]
        return torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1).reshape(length, 2 * quarter)

    rows = _encode(H).t()[:, :, None].expand(-1, H, W)
    cols = _encode(W).t()[:, None, :].expand(-1, H, W)
    # The original filled a default-dtype zeros tensor; keep that dtype.
    pos = torch.cat([rows, cols], dim=0).to(torch.get_default_dtype())
    return pos.unsqueeze(0)
class ConvGNBlock(nn.Module):
    """3x3 same-padded conv -> GroupNorm (at most 32 groups) -> GELU."""
    def __init__(self, channels: int):
        super().__init__()
        groups = channels if channels < 32 else 32
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm = nn.GroupNorm(groups, channels)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)
class DWResBlock(nn.Module):
    """Residual block: 1x1 pointwise conv -> GELU -> 3x3 depthwise conv ->
    GroupNorm (at most 32 groups), added back onto the input."""
    def __init__(self, channels: int):
        super().__init__()
        self.pw = nn.Conv2d(channels, channels, 1)
        self.act = nn.GELU()
        self.dw = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        self.norm = nn.GroupNorm(channels if channels < 32 else 32, channels)

    def forward(self, x: Tensor) -> Tensor:
        branch = self.act(self.pw(x))
        branch = self.norm(self.dw(branch))
        return x + branch
def make_tower(hidden: int, n_std: int, n_dw: int) -> nn.Sequential:
    """Stack ``n_std`` ConvGNBlocks followed by ``n_dw`` DWResBlocks, all at
    ``hidden`` channels, into one sequential tower."""
    blocks: List[nn.Module] = []
    for _ in range(n_std):
        blocks.append(ConvGNBlock(hidden))
    for _ in range(n_dw):
        blocks.append(DWResBlock(hidden))
    return nn.Sequential(*blocks)
class SplitTowerHead(nn.Module):
    """Detection head operating on a cofiber decomposition of the frozen
    backbone features. Five prediction levels (strides 8, 16, 32, 64, 128):
    a stride-8 level synthesized by a transposed convolution from the
    stride-16 band and four cofiber bands at strides 16, 32, 64, 128.
    Separate classification and regression towers of depth (n_std_layers +
    n_dw_layers) with weights shared across levels. Classification via
    cosine similarity against frozen CLIP text-encoder embeddings of the
    COCO class names; regression via exponentiated LTRB distances with a
    learned per-level scale; centerness via a single 1x1 convolution.
    Inference-only within Argus: no DFL, no IoU-aware branch, no
    per-scale bias. The text_embed buffer is populated by from_pretrained's
    state_dict load.

    forward input: ``spatial`` [B, feat_dim, H, W], the stride-16 patch grid.
    forward output: three lists with one entry per level (n_scales + 1
    entries, the 2H x 2W transposed-conv level first):
      - class logits       [B, num_classes, H_i, W_i]
      - LTRB distances     [B, 4, H_i, W_i], exp of a value clamped to
                           [-10, 10], i.e. already positive
      - centerness logits  [B, 1, H_i, W_i]"""
    def __init__(self,
                 feat_dim: int = 768,
                 hidden: int = 160,
                 n_std_layers: int = 5,
                 n_dw_layers: int = 4,
                 n_scales: int = 4,
                 pos_emb_dim: int = 64,
                 num_classes: int = 80,
                 text_embed_dim: int = 768):
        super().__init__()
        self.n_scales = n_scales
        self.pos_emb_dim = pos_emb_dim
        self.num_classes = num_classes
        self.text_embed_dim = text_embed_dim
        # Total prediction levels: the n_scales cofiber bands plus P3.
        n_total = n_scales + 1
        # Backbone channels plus the concatenated sinusoidal position channels.
        input_dim = feat_dim + pos_emb_dim
        # One LayerNorm-like GroupNorm(1, C) per cofiber band; the 1x1 stem
        # below is shared by every band.
        self.scale_norms = nn.ModuleList([nn.GroupNorm(1, input_dim) for _ in range(n_scales)])
        self.stem = nn.Conv2d(input_dim, hidden, 1)
        self.stem_act = nn.GELU()
        # kernel=2, stride=2 transposed conv exactly doubles the stride-16
        # band to produce the stride-8 level.
        self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2)
        self.p3_norm = nn.GroupNorm(min(32, hidden), hidden)
        # One lateral 1x1 conv + norm per finer band for top-down fusion.
        self.lateral_convs = nn.ModuleList([nn.Conv2d(hidden, hidden, 1) for _ in range(n_scales - 1)])
        self.lateral_norms = nn.ModuleList(
            [nn.GroupNorm(min(32, hidden), hidden) for _ in range(n_scales - 1)])
        self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers)
        # CLIP text-aligned classifier. The text_embed buffer is filled from
        # the state dict at from_pretrained; the zero placeholder here only
        # exists so the module can be constructed before weights arrive.
        self.register_buffer("text_embed",
                             torch.zeros(num_classes, text_embed_dim))
        self.cls_project = nn.Linear(hidden, text_embed_dim, bias=False)
        # CLIP-style temperature, initialised at 1/0.07 (stored as a log).
        self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / 0.07)))
        # Initial bias -log(99) makes sigmoid(bias) = 0.01, a low prior for
        # every class.
        self.cls_bias = nn.Parameter(torch.full((num_classes,), -math.log(99)))
        self.reg_pred = nn.Conv2d(hidden, 4, 1)
        self.ctr_pred = nn.Conv2d(hidden, 1, 1)
        # One learned multiplier on the raw regression output per level.
        self.scale_params = nn.Parameter(torch.ones(n_total))

    def forward(self, spatial: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]:
        B, C, H_, W_ = spatial.shape
        pos = make_sin_pos_emb(H_, W_, self.pos_emb_dim, spatial.device).expand(B, -1, -1, -1)
        spatial = torch.cat([spatial, pos], dim=1)
        # Band i has the input grid halved i times (finest first).
        cofibers = cofiber_decompose(spatial, self.n_scales)
        scale_features: List[Tensor] = []
        for i, cof in enumerate(cofibers):
            x = self.stem_act(self.stem(self.scale_norms[i](cof)))
            scale_features.append(x)
        # Top-down lateral fusion from coarser to finer scales.
        for i in range(len(scale_features) - 2, -1, -1):
            coarse_up = F.interpolate(scale_features[i + 1],
                                      size=scale_features[i].shape[2:],
                                      mode="bilinear", align_corners=False)
            scale_features[i] = self.lateral_norms[i](
                scale_features[i] + self.lateral_convs[i](coarse_up))
        p3 = self.p3_norm(self.p3_upsample(scale_features[0]))
        all_features = [p3] + scale_features
        cls_l, reg_l, ctr_l = [], [], []
        for i, x in enumerate(all_features):
            cls_feat = self.cls_tower(x)
            reg_feat = self.reg_tower(x)
            B_, _, Hi, Wi = cls_feat.shape
            # Per-location features as rows: [B*Hi*Wi, hidden].
            f = cls_feat.permute(0, 2, 3, 1).reshape(-1, cls_feat.shape[1])
            f_proj = self.cls_project(f)
            f_norm = F.normalize(f_proj, p=2, dim=-1)
            # NOTE(review): this is a true cosine similarity only if the
            # text_embed rows are stored L2-normalized — confirm in checkpoint.
            logits = f_norm @ self.text_embed.t()
            cls = (logits * self.logit_scale.exp() + self.cls_bias).reshape(
                B_, Hi, Wi, self.num_classes).permute(0, 3, 1, 2)
            # Clamp before exp so distances stay within [e^-10, e^10].
            reg_raw = (self.reg_pred(reg_feat) * self.scale_params[i]).clamp(-10, 10)
            reg = reg_raw.exp()
            # Centerness is predicted from the regression tower's features.
            ctr = self.ctr_pred(reg_feat)
            cls_l.append(cls)
            reg_l.append(reg)
            ctr_l.append(ctr)
        return cls_l, reg_l, ctr_l
def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]:
    """Image-space centers of every feature-map cell, one tensor per level.

    Each entry is [h * w, 2] float32 ``(x, y)`` pairs in row-major order
    (x varies fastest), with cell ``(r, c)`` centered at
    ``((c + 0.5) * stride, (r + 0.5) * stride)``. Levels beyond the shorter
    of ``feature_sizes`` / ``strides`` are ignored (``zip`` semantics).
    """
    def _centers(count: int, stride: int) -> Tensor:
        return (torch.arange(count, device=device, dtype=torch.float32) + 0.5) * stride

    locations: List[Tensor] = []
    for (h, w), stride in zip(feature_sizes, strides):
        gy, gx = torch.meshgrid(_centers(h, stride), _centers(w, stride), indexing="ij")
        locations.append(torch.stack((gx.reshape(-1), gy.reshape(-1)), dim=-1))
    return locations
def _level_candidates(
    cls_map: Tensor,
    reg_map: Tensor,
    ctr_map: Tensor,
    locs: Tensor,
    score_thresh: float,
    max_per_level: int,
) -> Optional[Tuple[Tensor, Tensor, Tensor]]:
    """Score-threshold and top-k one level of one image.

    ``cls_map`` is [C, H, W] logits, ``reg_map`` [4, H, W] LTRB distances,
    ``ctr_map`` [1, H, W] centerness logits, and ``locs`` [H*W, 2] xy centers.
    Returns ``(xyxy boxes, scores, labels)`` for the candidates, or None if
    no (location, class) pair scores above ``score_thresh``.
    """
    num_classes = cls_map.shape[0]
    cls = cls_map.permute(1, 2, 0).reshape(-1, num_classes)
    reg = reg_map.permute(1, 2, 0).reshape(-1, 4)
    ctr = ctr_map.permute(1, 2, 0).reshape(-1)
    scores = torch.sigmoid(cls) * torch.sigmoid(ctr)[:, None]
    mask = scores > score_thresh
    if not mask.any():
        return None
    cand_loc, cand_cls = mask.nonzero(as_tuple=True)
    cand_scores = scores[cand_loc, cand_cls]
    if cand_scores.numel() > max_per_level:
        top = cand_scores.topk(max_per_level)
        cand_scores = top.values
        cand_loc = cand_loc[top.indices]
        cand_cls = cand_cls[top.indices]
    xy = locs[cand_loc]
    ltrb = reg[cand_loc]
    boxes = torch.stack([
        xy[:, 0] - ltrb[:, 0],
        xy[:, 1] - ltrb[:, 1],
        xy[:, 0] + ltrb[:, 2],
        xy[:, 1] + ltrb[:, 3],
    ], dim=-1)
    return boxes, cand_scores, cand_cls


@torch.inference_mode()
def _decode_detections(
    cls_logits_per_level: List[Tensor],
    box_regs_per_level: List[Tensor],
    centernesses_per_level: List[Tensor],
    locations_per_level: List[Tensor],
    image_sizes: List[Tuple[int, int]],
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_level: int = 1000,
    max_per_image: int = 100,
) -> List[Dict[str, Tensor]]:
    """Convert per-level logits/regs/centerness into per-image detections.

    Per image: threshold ``sigmoid(cls) * sigmoid(ctr)`` at ``score_thresh``,
    keep the top ``max_per_level`` candidates per level, and decode LTRB
    distances around each location into xyxy boxes clamped to
    ``image_sizes[i]`` (given as ``(H, W)``). Then run class-wise NMS at
    ``nms_thresh`` and keep at most ``max_per_image`` detections.

    Returns one dict per image with ``boxes`` [K, 4], ``scores`` [K] and
    ``labels`` [K] (int64), always sorted by descending score.
    """
    device = cls_logits_per_level[0].device
    per_image_results: List[Dict[str, Tensor]] = []
    for image_idx in range(cls_logits_per_level[0].shape[0]):
        parts = []
        for cls_l, reg_l, ctr_l, locs_l in zip(
            cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level
        ):
            cand = _level_candidates(cls_l[image_idx], reg_l[image_idx], ctr_l[image_idx],
                                     locs_l, score_thresh, max_per_level)
            if cand is not None:
                parts.append(cand)
        if not parts:
            per_image_results.append({
                "boxes": torch.zeros((0, 4), device=device),
                "scores": torch.zeros((0,), device=device),
                "labels": torch.zeros((0,), dtype=torch.long, device=device),
            })
            continue
        box_parts, score_parts, label_parts = zip(*parts)
        boxes = torch.cat(box_parts, dim=0)
        scores = torch.cat(score_parts, dim=0)
        labels = torch.cat(label_parts, dim=0)
        H, W = image_sizes[image_idx]
        boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W)
        boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H)
        # Class-wise NMS: boxes only suppress boxes of the same label.
        keep_all = []
        for c in labels.unique():
            cm = labels == c
            keep = nms(boxes[cm], scores[cm], nms_thresh)
            keep_all.append(cm.nonzero(as_tuple=True)[0][keep])
        keep = torch.cat(keep_all, dim=0)
        boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
        # Sort by score in every case (previously only when truncating), so
        # output order does not depend on how many detections survived.
        order = scores.argsort(descending=True)[:max(max_per_image, 0)]
        per_image_results.append({
            "boxes": boxes[order],
            "scores": scores[order],
            "labels": labels[order],
        })
    return per_image_results
def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]:
    """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform.

    Returns ``(canvas, scale, (W0, H0))``:
      - canvas: ``resolution`` x ``resolution`` RGB image with the resized
        input pasted at the top-left corner,
      - scale: factor applied to the original size (long side -> resolution);
        divide canvas-space coordinates by it to map back,
      - (W0, H0): the original width and height.
    """
    W0, H0 = image.size
    scale = resolution / max(H0, W0)
    # Clamp to >= 1 px: for extreme aspect ratios the short side can round
    # to 0, which PIL's resize rejects with a ValueError.
    new_w = max(1, int(round(W0 * scale)))
    new_h = max(1, int(round(H0 * scale)))
    resized = image.resize((new_w, new_h), Image.BILINEAR)
    canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0))
    canvas.paste(resized, (0, 0))
    return canvas, scale, (W0, H0)
# ===========================================================================
# DPT depth decoder (multi-scale, hooks into ViT blocks [2, 5, 8, 11])
# ===========================================================================
# Backbone block indices whose outputs Argus.depth captures with forward
# hooks. The list order is the order the DPT reassemble stage consumes them
# (shallowest block -> finest pyramid level).
HOOK_BLOCK_INDICES = [2, 5, 8, 11]
# Tokens that precede the patch tokens in each block output; _DPTReassemble
# drops them before reshaping to the patch grid.
N_PREFIX_TOKENS = 5  # 1 CLS + 4 register/storage tokens
class _ResidualConvUnit(nn.Module):
    """Residual unit: conv3x3 -> BN -> GELU -> conv3x3 -> BN, added to the input.

    Both convs are bias-free and use the given ``padding_mode``. The
    Argus-B DPT depth head trains with reflect padding to avoid edge
    artifacts, while Argus-Lite ships weights trained with zero padding
    (the PyTorch default). Switching pad modes at inference would create a
    small distribution shift in the edge regions, so each variant passes
    the mode it was trained with.
    """
    def __init__(self, dim: int, padding_mode: str = "reflect"):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode=padding_mode, bias=False)
        self.conv1 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(dim)
        self.conv2 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(dim)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        residual = self.act(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return x + residual
class _FeatureFusionBlock(nn.Module):
    """One DPT fusion stage: add a 1x1-projected skip (when configured),
    refine with two residual conv units, then upsample 2x bilinearly.

    ``padding_mode`` is forwarded to both residual units so a variant's
    inference matches the padding its weights were trained with.
    """
    def __init__(self, dim: int, has_skip: bool = True, padding_mode: str = "reflect"):
        super().__init__()
        self.rcu1 = _ResidualConvUnit(dim, padding_mode=padding_mode)
        self.rcu2 = _ResidualConvUnit(dim, padding_mode=padding_mode)
        self.skip_proj = nn.Conv2d(dim, dim, 1) if has_skip else None

    def forward(self, x: Tensor, skip: Optional[Tensor] = None) -> Tensor:
        if skip is not None and self.skip_proj is not None:
            if x.shape[-2:] != skip.shape[-2:]:
                # An odd stride-16 grid makes the stride-32 level floor(H/2),
                # so the previous stage's 2x upsample comes back one pixel
                # short of the skip. Resample to the skip's grid instead of
                # failing on the add. This is a no-op for even grids.
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = x + self.skip_proj(skip)
        x = self.rcu1(x)
        x = self.rcu2(x)
        return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
class _DPTReassemble(nn.Module):
    """Turn four hooked ViT block outputs into a 4-level feature pyramid.

    Each input is [B, N_PREFIX_TOKENS + H*W, in_dim]. Prefix tokens are
    dropped, patches are projected (LayerNorm + Linear) to ``out_dim`` and
    reshaped to [B, out_dim, H, W], then refined by conv3x3 + BN + GELU.
    The four maps are resampled by x4, x2, x1 and x0.5 (finest first).
    ``padding_mode`` applies to the refine convs.
    """
    def __init__(self, in_dim: int = 768, out_dim: int = 256, padding_mode: str = "reflect"):
        super().__init__()
        self.projects = nn.ModuleList([
            nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, out_dim))
            for _ in range(4)
        ])
        self.refine = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(out_dim, out_dim, 3, padding=1, padding_mode=padding_mode, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.GELU(),
            )
            for _ in range(4)
        ])

    def forward(self, intermediates: List[Tensor], H: int, W: int) -> List[Tensor]:
        out = []
        for feat, proj, refine in zip(intermediates, self.projects, self.refine):
            patches = feat[:, N_PREFIX_TOKENS:, :]
            patches = proj(patches)
            B, N, D = patches.shape
            if N != H * W:
                raise ValueError(
                    f"expected {H}x{W}={H * W} patch tokens after dropping "
                    f"{N_PREFIX_TOKENS} prefix tokens, got {N}")
            spatial = patches.permute(0, 2, 1).reshape(B, D, H, W)
            out.append(refine(spatial))
        level_4 = F.interpolate(out[0], scale_factor=4, mode="bilinear", align_corners=False)
        level_8 = F.interpolate(out[1], scale_factor=2, mode="bilinear", align_corners=False)
        level_16 = out[2]
        level_32 = F.interpolate(out[3], scale_factor=0.5, mode="bilinear", align_corners=False)
        return [level_4, level_8, level_16, level_32]


class DPTDepthDecoder(nn.Module):
    """DPT-style multi-scale depth decoder with bin-expectation output.

    Fuses the reassembled pyramid coarse-to-fine, predicts ``n_bins`` logits
    per pixel, and returns the expectation over ``n_bins`` evenly spaced
    depths in ``[min_depth, max_depth]``. For an H x W patch grid (even
    sides) the output is [B, 1, 8H, 8W].

    ``padding_mode`` is applied to every 3x3 conv (reassemble refine, fusion
    residual units, head). The default "reflect" matches the Argus-B
    checkpoint. Variants whose weights were trained with zero padding pass
    "zeros".
    """
    def __init__(self, in_dim: int = 768, decoder_dim: int = 256,
                 n_bins: int = 256, min_depth: float = 0.001, max_depth: float = 10.0,
                 padding_mode: str = "reflect"):
        super().__init__()
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.padding_mode = padding_mode
        self.reassemble = _DPTReassemble(in_dim=in_dim, out_dim=decoder_dim,
                                         padding_mode=padding_mode)
        self.fusion_blocks = nn.ModuleList([
            _FeatureFusionBlock(decoder_dim, has_skip=True, padding_mode=padding_mode),
            _FeatureFusionBlock(decoder_dim, has_skip=True, padding_mode=padding_mode),
            _FeatureFusionBlock(decoder_dim, has_skip=True, padding_mode=padding_mode),
            _FeatureFusionBlock(decoder_dim, has_skip=False, padding_mode=padding_mode),
        ])
        self.head = nn.Sequential(
            nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, padding_mode=padding_mode, bias=False),
            nn.BatchNorm2d(decoder_dim),
            nn.GELU(),
            nn.Conv2d(decoder_dim, n_bins, 1),
        )

    def forward(self, intermediates: List[Tensor], H: int, W: int,
                return_distribution: bool = False):
        """Return depth [B, 1, 8H, 8W]. With ``return_distribution=True``,
        return ``(depth, distribution [B, n_bins, 8H, 8W], bins [n_bins])``."""
        levels = self.reassemble(intermediates, H, W)
        x = self.fusion_blocks[3](levels[3])
        x = self.fusion_blocks[2](x, skip=levels[2])
        x = self.fusion_blocks[1](x, skip=levels[1])
        x = self.fusion_blocks[0](x, skip=levels[0])
        logits = self.head(x)
        # ReLU + 0.1 keeps every bin weight positive before normalizing.
        distribution = torch.relu(logits) + 0.1
        distribution = distribution / distribution.sum(dim=1, keepdim=True)
        bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device)
        depth = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1)
        if return_distribution:
            return depth, distribution, bins
        return depth
# ===========================================================================
# Argus model (transformers-compatible)
# ===========================================================================
class ArgusConfig(PretrainedConfig):
    """transformers config for :class:`Argus`.

    Field groups (as consumed by ``Argus.__init__``):
      - ``embed_dim``: backbone channel width fed to every head.
      - ``num_seg_classes``: output classes of ``SegmentationHead``.
      - ``depth_n_bins`` / ``depth_min_depth`` / ``depth_max_depth``: DPT
        depth decoder bin count and depth range.
      - ``num_imagenet_classes``: rows of the ImageNet classifier buffers.
      - ``class_ids`` / ``class_names``: ImageNet metadata; empty by default
        and filled lazily from ``imagenet_classes.json`` on first access.
      - ``detection_*``: ``SplitTowerHead`` hyper-parameters, plus
        ``detection_class_names`` (defaults to ``COCO_CLASSES``).
    """
    model_type = "argus"

    def __init__(
        self,
        embed_dim: int = 768,
        patch_size: int = 16,
        num_seg_classes: int = 150,
        depth_n_bins: int = 256,
        depth_min_depth: float = 0.001,
        depth_max_depth: float = 10.0,
        num_imagenet_classes: int = 1000,
        class_ids: Optional[list] = None,
        class_names: Optional[list] = None,
        detection_num_classes: int = 80,
        detection_hidden: int = 160,
        detection_n_std_layers: int = 5,
        detection_n_dw_layers: int = 4,
        detection_n_scales: int = 4,
        detection_pos_emb_dim: int = 64,
        detection_text_embed_dim: int = 768,
        detection_class_names: Optional[list] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_seg_classes = num_seg_classes
        self.depth_n_bins = depth_n_bins
        self.depth_min_depth = depth_min_depth
        self.depth_max_depth = depth_max_depth
        self.num_imagenet_classes = num_imagenet_classes
        # None -> [] so the lazy-loading properties on Argus see "empty".
        self.class_ids = class_ids or []
        self.class_names = class_names or []
        self.detection_num_classes = detection_num_classes
        self.detection_hidden = detection_hidden
        self.detection_n_std_layers = detection_n_std_layers
        self.detection_n_dw_layers = detection_n_dw_layers
        self.detection_n_scales = detection_n_scales
        self.detection_pos_emb_dim = detection_pos_emb_dim
        self.detection_text_embed_dim = detection_text_embed_dim
        # Copy so configs never share (and mutate) the module-level list.
        self.detection_class_names = detection_class_names or list(COCO_CLASSES)
class Argus(PreTrainedModel):
config_class = ArgusConfig
base_model_prefix = "argus"
supports_gradient_checkpointing = False
_tied_weights_keys: list = []
all_tied_weights_keys: dict = {}
def __init__(self, config: ArgusConfig):
    """Build the EUPE-ViT-B backbone and every Argus task head.

    Submodules: ``backbone`` (``build_eupe_vitb16()``), ``seg_head``,
    ``depth_head`` (DPT decoder, decoder width fixed at 256),
    ``detection_head`` (``SplitTowerHead``), plus the ImageNet linear
    classifier held in the persistent buffers ``class_logit_weight`` /
    ``class_logit_bias``. Those buffers start as zeros; presumably the
    trained values arrive with the checkpoint state dict.
    """
    super().__init__(config)
    self.backbone = build_eupe_vitb16()
    self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes)
    self.depth_head = DPTDepthDecoder(
        in_dim=config.embed_dim,
        decoder_dim=256,
        n_bins=config.depth_n_bins,
        min_depth=config.depth_min_depth,
        max_depth=config.depth_max_depth,
    )
    # ImageNet classifier applied to the L2-normalized CLS token in
    # classify(). Stored as buffers (saved with the state dict, not
    # returned by parameters()).
    self.register_buffer(
        "class_logit_weight",
        torch.zeros(config.num_imagenet_classes, config.embed_dim),
        persistent=True,
    )
    self.register_buffer(
        "class_logit_bias",
        torch.zeros(config.num_imagenet_classes),
        persistent=True,
    )
    self.detection_head = SplitTowerHead(
        feat_dim=config.embed_dim,
        hidden=config.detection_hidden,
        n_std_layers=config.detection_n_std_layers,
        n_dw_layers=config.detection_n_dw_layers,
        n_scales=config.detection_n_scales,
        pos_emb_dim=config.detection_pos_emb_dim,
        num_classes=config.detection_num_classes,
        text_embed_dim=config.detection_text_embed_dim,
    )
    # Inference model: freeze the backbone and put backbone and heads in
    # eval mode (BatchNorm layers then use running statistics).
    for p in self.backbone.parameters():
        p.requires_grad = False
    self.backbone.eval()
    self.seg_head.eval()
    self.depth_head.eval()
    self.detection_head.eval()
def _init_weights(self, module):
    """Give freshly allocated (uninitialized) tensors sane values.

    On from_pretrained, HF reallocates missing buffers and parameters with
    torch.empty(), i.e. uninitialized memory. Conv and transposed-conv
    layers get Kaiming-normal weights (fan_out, relu) and zero bias.
    GroupNorm gets unit weight and zero bias. When called on the model
    itself, the ImageNet classifier buffers are zeroed if they contain any
    NaN or Inf.
    """
    if isinstance(module, nn.GroupNorm):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    if module is not self:
        return
    for buffer_name in ("class_logit_weight", "class_logit_bias"):
        if not hasattr(self, buffer_name):
            continue
        buffer = getattr(self, buffer_name)
        if not torch.isfinite(buffer).all():
            buffer.data.zero_()
def _load_imagenet_classes(self):
    """Fill config.class_ids / config.class_names from imagenet_classes.json.

    Runs at most once per instance. Lookup order:
      1. next to this source file,
      2. inside ``config._name_or_path`` when it is a local directory,
      3. the Hugging Face Hub when ``_name_or_path`` is set but not a local
         directory (treated as a repo id).
    Local read/parse errors propagate. A failed hub download, or no file
    found anywhere, leaves the lists empty and emits a RuntimeWarning
    instead of raising.
    """
    if getattr(self, "_imagenet_classes_loaded", False):
        return
    self._imagenet_classes_loaded = True
    import json
    import os as _os
    import warnings

    def _apply(path):
        # JSON is UTF-8 by spec; don't depend on the platform locale.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        self.config.class_ids = data.get("class_ids", [])
        self.config.class_names = data.get("class_names", [])

    here = _os.path.dirname(_os.path.abspath(__file__))
    candidates = [_os.path.join(here, "imagenet_classes.json")]
    name_or_path = getattr(self.config, "_name_or_path", None)
    if name_or_path and _os.path.isdir(name_or_path):
        candidates.append(_os.path.join(name_or_path, "imagenet_classes.json"))
    for path in candidates:
        if _os.path.isfile(path):
            _apply(path)
            return
    if name_or_path and not _os.path.isdir(name_or_path):
        try:
            from huggingface_hub import hf_hub_download
            _apply(hf_hub_download(name_or_path, "imagenet_classes.json"))
            return
        except Exception as err:  # best effort: offline, missing file, bad repo id, ...
            warnings.warn(
                f"Could not fetch imagenet_classes.json for {name_or_path!r}: {err}",
                RuntimeWarning, stacklevel=2)
            return
    warnings.warn("imagenet_classes.json not found; ImageNet class labels are unavailable.",
                  RuntimeWarning, stacklevel=2)
@property
def class_ids(self):
    """ImageNet class ids, loaded lazily from imagenet_classes.json on first use."""
    cached = self.config.class_ids
    if cached:
        return cached
    self._load_imagenet_classes()
    return self.config.class_ids
@property
def class_names(self):
    """ImageNet class names, loaded lazily from imagenet_classes.json on first use."""
    cached = self.config.class_names
    if cached:
        return cached
    self._load_imagenet_classes()
    return self.config.class_names
def quantize_int8(self):
    """Apply INT8 weight-only quantization via torchao, in place.

    Reduces VRAM by ~11% with negligible accuracy loss (<0.05 m depth
    drift, 100% classification agreement). Requires torchao
    (pip install torchao); raises ImportError if it is missing.
    Returns ``self`` so calls can be chained.
    """
    try:
        from torchao.quantization import Int8WeightOnlyConfig, quantize_
    except ImportError as missing:
        raise ImportError("torchao is required for INT8 quantization: pip install torchao") from missing
    config = Int8WeightOnlyConfig()
    quantize_(self, config)
    return self
@torch.inference_mode()
def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the backbone and return float32 ``(cls [B, D], spatial [B, D, h, w])``.

    The backbone runs under bf16 autocast on CUDA. Patch tokens are
    reshaped into a square ``h x w`` grid.

    Raises:
        ValueError: if the number of patch tokens is not a perfect square
            (non-square input).
    """
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
        out = self.backbone.forward_features(image_tensor)
    cls = out["x_norm_clstoken"].float()
    patches = out["x_norm_patchtokens"].float()
    B, N, D = patches.shape
    side = math.isqrt(N)
    if side * side != N:
        raise ValueError(
            f"expected a square patch grid, got {N} patch tokens "
            f"for input of shape {tuple(image_tensor.shape)}")
    spatial = patches.permute(0, 2, 1).reshape(B, D, side, side)
    return cls, spatial
@torch.inference_mode()
def classify(self, image_or_images, top_k: int = 5):
    """ImageNet classification from the backbone CLS token.

    Images go through the 224px EUPE transform. The L2-normalized CLS
    embedding is passed through the stored linear classifier and softmaxed.

    Returns, per image, a list of up to ``top_k`` dicts
    ``{"class_id", "class_name", "score"}`` in descending score order. The
    first dict also carries ``"margin"`` (top-1 minus top-2 probability).
    A single image yields one list; a batch yields a list of lists.
    ``top_k`` is capped at the number of classes.

    Raises:
        ValueError: if ``top_k < 1``.
        RuntimeError: if the ImageNet class metadata could not be loaded.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    single, images = _normalize_image_input(image_or_images)
    transform = make_eupe_transform(224)
    batch = torch.stack([transform(img) for img in images]).to(self.device)
    cls, _ = self._extract(batch)
    cls = F.normalize(cls, dim=-1)
    weight = self.class_logit_weight.to(cls.dtype)
    bias = self.class_logit_bias.to(cls.dtype)
    probs = F.softmax(F.linear(cls, weight, bias), dim=-1)
    top = probs.topk(min(top_k, probs.shape[-1]), dim=-1)
    top2 = probs.topk(2, dim=-1).values
    margins = (top2[:, 0] - top2[:, 1]).tolist()
    class_ids = self.class_ids
    class_names = self.class_names
    if not class_ids or not class_names:
        raise RuntimeError(
            "ImageNet class metadata (imagenet_classes.json) is unavailable; "
            "cannot label predictions")
    results = []
    # One host transfer for the whole batch instead of per-image .tolist().
    for scores, indices, margin in zip(top.values.tolist(), top.indices.tolist(), margins):
        entries = [
            {"class_id": class_ids[idx], "class_name": class_names[idx], "score": float(score)}
            for score, idx in zip(scores, indices)
        ]
        entries[0]["margin"] = float(margin)
        results.append(entries)
    return results[0] if single else results
@torch.inference_mode()
def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False):
    """Semantic segmentation at ``resolution`` x ``resolution``.

    Returns a [H, W] class-index map per image. With
    ``return_confidence=True`` each result is ``(seg_map, conf_map)``, where
    ``conf_map`` is the per-pixel max softmax probability. A single image
    yields one result; a batch yields a list.
    """
    single, images = _normalize_image_input(image_or_images)
    preprocess = make_eupe_transform(resolution)
    pixels = torch.stack([preprocess(im) for im in images]).to(self.device)
    _, features = self._extract(pixels)
    use_bf16 = self.device.type == "cuda"
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=use_bf16):
        class_logits = self.seg_head(features)
        class_logits = F.interpolate(class_logits, size=(resolution, resolution),
                                     mode="bilinear", align_corners=False)
    label_maps = class_logits.argmax(dim=1)  # [B, H, W]
    count = len(images)
    if not return_confidence:
        if single:
            return label_maps[0]
        return [label_maps[i] for i in range(count)]
    confidence = F.softmax(class_logits.float(), dim=1).max(dim=1).values  # [B, H, W] in [0, 1]
    if single:
        return label_maps[0], confidence[0]
    return [(label_maps[i], confidence[i]) for i in range(count)]
@torch.inference_mode()
def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False,
          crop_border: bool = False):
    """Run the DPT depth decoder. Returns metric depth in meters at the
    input resolution.

    ``crop_border=True`` strips a small border (``max(4, H/13)`` pixels per
    side) from the raw decoder output before bilinear-upsampling to the
    input resolution. Useful when this model is loaded with a backbone
    whose DPT decoder was trained with zero padding (the unshipped
    dev-fork behaviour), which leaves a systematic edge artifact. The
    canonical checkpoint uses reflect padding inside every DPT conv and
    does not need this crop, so the option defaults to ``False``.

    With ``return_confidence=True`` each result is ``(depth, std)``, where
    ``std`` is the standard deviation of the per-pixel depth-bin
    distribution (same units as depth). The forward hooks used to capture
    intermediate blocks are always removed, even if the backbone raises."""
    single, images = _normalize_image_input(image_or_images)
    transform = make_eupe_transform(resolution)
    batch = torch.stack([transform(img) for img in images]).to(self.device)
    # Hook into intermediate ViT blocks for multi-scale features.
    intermediates: Dict[int, Tensor] = {}

    def _make_hook(block_idx):
        def _hook(module, inp, out):
            intermediates[block_idx] = out[0] if isinstance(out, list) else out
        return _hook

    handles = []
    try:
        for idx in HOOK_BLOCK_INDICES:
            handles.append(self.backbone.blocks[idx].register_forward_hook(_make_hook(idx)))
        with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
            self.backbone.forward_features(batch)
    finally:
        # A failed forward (e.g. CUDA OOM) must not leave hooks on the shared
        # backbone: they would fire on every later call and pin activations.
        for handle in handles:
            handle.remove()
    inter_list = [intermediates[idx].float() for idx in HOOK_BLOCK_INDICES]
    H = W = resolution // 16
    if return_confidence:
        depth_b, distribution, bins = self.depth_head(
            inter_list, H, W, return_distribution=True)
        # Std of the 256-bin depth distribution: var = E[X^2] - E[X]^2.
        mean_sq = torch.einsum("bkhw,k->bhw", distribution, bins ** 2)
        variance = (mean_sq - depth_b.squeeze(1) ** 2).clamp(min=0)
        std_b = torch.sqrt(variance).unsqueeze(1)
    else:
        depth_b = self.depth_head(inter_list, H, W)
        std_b = None
    if crop_border:
        crop = max(4, depth_b.shape[2] // 13)
        depth_b = depth_b[:, :, crop:-crop, crop:-crop]
        if std_b is not None:
            std_b = std_b[:, :, crop:-crop, crop:-crop]
    depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
    if std_b is not None:
        std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
    depth_squeezed = depth_b[:, 0].float()
    if return_confidence:
        std_squeezed = std_b[:, 0].float()
        if single:
            return depth_squeezed[0], std_squeezed[0]
        return [(depth_squeezed[i], std_squeezed[i]) for i in range(len(images))]
    if single:
        return depth_squeezed[0]
    return [depth_squeezed[i] for i in range(len(images))]
@torch.inference_mode()
def correspond(
    self,
    src_image,
    tgt_image,
    resolution: int = 512,
):
    """Dense patch correspondence between two images.

    Single-pair form: pass two `PIL.Image` instances. Returns a dict with
    keys `matches` (numpy array of length grid*grid mapping each source
    patch to its argmax target patch), `scores` (cosine similarity at the
    match), and `grid` (the patch-grid side length).

    Batched form: pass two equally-sized lists/iterables of images. Returns
    a list of per-pair dicts in the same shape that a single call would
    produce. Both lists are forwarded through the backbone in two
    contiguous batches, so cross-pair throughput on GPU is much higher
    than calling `correspond` in a loop.

    Raises:
        ValueError: if the lists differ in length or are empty.
        TypeError: if any element is not a `PIL.Image`.
    """
    single = isinstance(src_image, Image.Image) and isinstance(tgt_image, Image.Image)
    if single:
        srcs = [src_image]
        tgts = [tgt_image]
    else:
        srcs = list(src_image)
        tgts = list(tgt_image)
        if len(srcs) != len(tgts):
            raise ValueError(
                f"src_image and tgt_image must have the same length; "
                f"got {len(srcs)} and {len(tgts)}")
        if not srcs:
            raise ValueError("empty image list")
        for i, (a, b) in enumerate(zip(srcs, tgts)):
            if not isinstance(a, Image.Image) or not isinstance(b, Image.Image):
                raise TypeError(f"pair {i} must contain two PIL.Image instances")
    transform = make_eupe_transform(resolution)
    src_batch = torch.stack([transform(img) for img in srcs]).to(self.device)
    tgt_batch = torch.stack([transform(img) for img in tgts]).to(self.device)
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"):
        oa = self.backbone.forward_features(src_batch)
        ob = self.backbone.forward_features(tgt_batch)
    pa_batch = F.normalize(oa['x_norm_patchtokens'].float(), dim=-1)
    pb_batch = F.normalize(ob['x_norm_patchtokens'].float(), dim=-1)
    # One max() per pair gives both the argmax and the score (previously two
    # reductions). The similarity matrix stays per pair to bound memory.
    best_idx, best_val = [], []
    for pa, pb in zip(pa_batch, pb_batch):
        best = (pa @ pb.t()).max(dim=-1)
        best_idx.append(best.indices)
        best_val.append(best.values)
    # One device-to-host transfer per output instead of two per pair.
    matches = torch.stack(best_idx).cpu().numpy()
    scores = torch.stack(best_val).cpu().numpy()
    grid = math.isqrt(pa_batch.shape[1])
    results = [
        {"matches": matches[i], "scores": scores[i], "grid": grid}
        for i in range(len(srcs))
    ]
    return results[0] if single else results
@torch.inference_mode()
def detect(
    self,
    image_or_images,
    resolution: int = 768,
    score_thresh: float = 0.05,
    nms_thresh: float = 0.5,
    max_per_image: int = 100,
):
    """Detect objects in one image or a batch of images.

    Pipeline:
    1. Letterbox each image onto a square canvas of side ``resolution``.
    2. Normalize with ImageNet mean/std.
    3. Run the backbone (``self._extract``) and the detection head.
    4. Decode with score thresholding and NMS.
    5. Map the boxes back to each image's original pixel frame.

    Args:
        image_or_images: a single image or a collection, as accepted by
            ``_normalize_image_input``.
        resolution: side length of the square letterbox canvas.
        score_thresh: minimum score kept by the decoder.
        nms_thresh: IoU threshold used by the decoder's NMS.
        max_per_image: cap on detections returned per image.

    Returns:
        For a single image, a list of detection dicts; otherwise a list of
        such lists, one per image. Each detection has these keys:

        - ``box``: four floats in original-image pixels. Even indices are
          clipped to the width and odd indices to the height.
        - ``score``: float.
        - ``label``: int.
        - ``class_name``: looked up from ``config.detection_class_names``,
          or ``"class_<label>"`` when the label is out of range.
    """
    single, images = _normalize_image_input(image_or_images)

    # Letterbox to match the training transform: resize the long side to
    # `resolution` and pad bottom/right with black. Keep each image's scale
    # and original size so boxes can be un-letterboxed after decoding.
    prepared = []
    for picture in images:
        canvas, scale, orig = _letterbox_to_square(picture, resolution)
        prepared.append((canvas, scale, orig))

    to_input = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    pixel_batch = torch.stack(
        [to_input(canvas) for canvas, _, _ in prepared]
    ).to(self.device)

    _, spatial = self._extract(pixel_batch)
    use_bf16 = self.device.type == "cuda"
    with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=use_bf16):
        raw_cls, raw_box, raw_ctr = self.detection_head(spatial)
    # Decode in float32 regardless of the autocast dtype above.
    cls_logits = [t.float() for t in raw_cls]
    box_regs = [t.float() for t in raw_box]
    centernesses = [t.float() for t in raw_ctr]

    locations = _make_locations(
        [(level.shape[2], level.shape[3]) for level in cls_logits],
        FPN_STRIDES,
        spatial.device,
    )
    decoded = _decode_detections(
        cls_logits, box_regs, centernesses, locations,
        image_sizes=[(resolution, resolution)] * len(images),
        score_thresh=score_thresh,
        nms_thresh=nms_thresh,
        max_per_image=max_per_image,
    )

    names = self.config.detection_class_names
    per_image = []
    for i, det in enumerate(decoded):
        _, scale, (orig_w, orig_h) = prepared[i]
        # Undo the letterbox resize, then clip x (even) / y (odd) coords.
        xyxy = det["boxes"].cpu().numpy() / scale
        xyxy[:, 0::2] = xyxy[:, 0::2].clip(0, orig_w)
        xyxy[:, 1::2] = xyxy[:, 1::2].clip(0, orig_h)
        confidences = det["scores"].cpu().numpy()
        label_ids = det["labels"].cpu().numpy()

        entries = []
        for box, conf, raw_label in zip(xyxy, confidences, label_ids):
            label = int(raw_label)
            if label < len(names):
                name = names[label]
            else:
                name = f"class_{label}"
            entries.append({
                "box": [float(v) for v in box.tolist()],
                "score": float(conf),
                "label": label,
                "class_name": name,
            })
        per_image.append(entries)

    return per_image[0] if single else per_image
def perceive(self, image_or_images, return_confidence: bool = False):
    """Run classification, segmentation, and depth estimation in one call.

    Args:
        image_or_images: one image or several images, in any form accepted by
            ``_normalize_image_input``. Whether a single result or a list is
            returned follows the ``single`` flag that helper reports.
        return_confidence: forwarded to ``segment`` and ``depth``. When True,
            their outputs are (map, confidence/uncertainty) pairs, and each
            result also carries ``segmentation_confidence`` and
            ``depth_uncertainty`` arrays.

    Returns:
        A dict for a single image, otherwise a list of dicts (one per image)
        with these keys:

        - ``classification``: entry from ``classify(..., top_k=5)``.
        - ``segmentation``: numpy array.
        - ``depth``: numpy array.
        - ``timings_ms``: batch-level latencies, described below.

    ``timings_ms`` reports latencies in milliseconds for ``classify``,
    ``segment`` (run at 512 px), ``depth`` (run at 416 px) and their
    ``total``. Each result gets its own copy of that dict, and the values are
    the same for every image in the batch. Timestamps come from the monotonic
    ``time.perf_counter``. On CUDA the device is synchronized at each
    timestamp, so asynchronously queued kernels are charged to the task that
    launched them.
    """
    single, images = _normalize_image_input(image_or_images)
    sync_cuda = self.device.type == "cuda"

    def _now() -> float:
        # CUDA work is queued asynchronously; wait for it so each timestamp
        # marks when the previous task actually finished, not when its
        # kernels were merely launched.
        if sync_cuda:
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

    t0 = _now()
    classif = self.classify(images, top_k=5)
    t1 = _now()
    seg_out = self.segment(images, resolution=512, return_confidence=return_confidence)
    t2 = _now()
    depth_out = self.depth(images, resolution=416, return_confidence=return_confidence)
    t3 = _now()

    if return_confidence:
        seg_maps = [s for s, _ in seg_out]
        seg_confs = [c for _, c in seg_out]
        depth_maps = [d for d, _ in depth_out]
        depth_uncerts = [u for _, u in depth_out]
    else:
        seg_maps = seg_out
        depth_maps = depth_out
        seg_confs = depth_uncerts = None

    timings = {
        "classify": (t1 - t0) * 1000,
        "segment": (t2 - t1) * 1000,
        "depth": (t3 - t2) * 1000,
        "total": (t3 - t0) * 1000,
    }

    results = []
    for i in range(len(images)):
        entry = {
            "classification": classif[i],
            "segmentation": seg_maps[i].cpu().numpy(),
            "depth": depth_maps[i].cpu().numpy(),
            # Per-entry copy so mutating one result cannot affect the others.
            "timings_ms": dict(timings),
        }
        if return_confidence:
            entry["segmentation_confidence"] = seg_confs[i].cpu().numpy()
            entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy()
        results.append(entry)
    return results[0] if single else results
def export_onnx(
    self,
    out_dir: str,
    backbone_resolution: int = 224,
    dynamic_batch: bool = True,
    verify: bool = True,
    tolerance: Union[float, Dict[str, float]] = 5e-2,
    opset_version: int = 17,
    include_nms: bool = False,
    nms_iou_threshold: float = 0.5,
    nms_score_threshold: float = 0.05,
    nms_max_detections: int = 100,
) -> dict:
    """Export backbone, classifier, seg head, depth head, and detection head to ONNX.

    Produces five graphs:
    - argus_backbone.onnx        image[B,3,H,W] -> cls_token[B,D], spatial_features[B,D,H/16,W/16]
    - argus_classifier.onnx      cls_token[B,D] -> class_probs[B,1000]
    - argus_seg_head.onnx        spatial_features[B,D,h,w] -> seg_logits[B,150,H,W]
    - argus_depth_head.onnx      intermediate_{0..3}[B,N+5,D] -> depth_map[B,1,~8h,~8w]
    - argus_detection_head.onnx  spatial_features[B,D,h,w] -> boxes, scores
                                 (+ class_labels, batch_indices if include_nms)

    The seg graph folds the bilinear upsample to input resolution into the
    graph, so consumers can argmax directly without a separate interpolation
    step. Correspondence has no learned parameters. It runs as cosine-max on
    the backbone's spatial output and needs no graph.

    ``include_nms=True`` bakes an ONNX NonMaxSuppression op (opset >= 10)
    into the detection head. The detection graph then emits four post-NMS
    tensors instead of the raw (boxes, scores) pair: boxes [M,4], scores [M],
    class_labels [M] and batch_indices [M]. This is useful for single-shot
    TensorRT or mobile inference. The default ``include_nms=False`` leaves
    NMS to the consumer. They can then choose hard vs soft and per-class vs
    global NMS, and tune thresholds, without re-exporting.

    ``tolerance`` can be a float or a dict:

    - A float is applied uniformly to every ``*_max_diff`` check, except
      detection box coordinates. Those get a resolution-scaled tolerance,
      ``max(tolerance, backbone_resolution * 5e-3)``, because exp() in the
      FCOS regression path amplifies FP kernel-dispatch differences into
      pixel-scale absolute diffs.
    - A dict is keyed by verification output name, e.g.
      ``{"detection_boxes_max_diff": 3.2, "default": 5e-2}``. The
      ``"default"`` key covers outputs not otherwise listed.

    Args:
        out_dir: Directory for the ``.onnx`` files; created if missing.
        backbone_resolution: Square input side in pixels. Must be a multiple
            of ``config.patch_size``. The head graphs are traced for the
            matching patch grid, ``backbone_resolution // patch_size``.
        dynamic_batch: Mark the batch axis as dynamic. When False, every graph
            is traced with a fixed batch of 1 and verification runs at
            batch 1.
        verify: Run each graph under onnxruntime (CPU) on random inputs and
            compare against the PyTorch export wrappers.
        tolerance: Float or per-metric dict, as described above.
        opset_version: ONNX opset passed to ``torch.onnx.export``.
        include_nms: Forwarded to the detection export wrapper.
        nms_iou_threshold: Forwarded to the detection export wrapper.
        nms_score_threshold: Forwarded to the detection export wrapper.
        nms_max_detections: Forwarded to the detection export wrapper.

    Returns:
        Dict mapping ``backbone``, ``classifier``, ``seg_head``,
        ``depth_head`` and ``detection_head`` to the written file paths.
        When ``verify`` is True it also holds ``verification``, with the
        max diffs, NMS statistics and tolerances used.

    Raises:
        ValueError: ``backbone_resolution`` is not a multiple of the patch size.
        ImportError: ``verify`` is True but onnxruntime is not installed.
        RuntimeError: A ``*_max_diff`` verification metric exceeds its tolerance.
    """
    import os
    # Validate before touching the filesystem so a bad call leaves no empty
    # output directory behind.
    if backbone_resolution % self.config.patch_size != 0:
        raise ValueError(
            f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})"
        )
    os.makedirs(out_dir, exist_ok=True)
    spatial_resolution = backbone_resolution // self.config.patch_size
    if backbone_resolution < 320:
        import warnings
        warnings.warn(
            f"backbone_resolution={backbone_resolution} is below 320; the detection "
            f"head's coarsest FPN level (stride 128) collapses to <=2 locations per "
            f"side and the detection graph, while it exports and runs, cannot produce "
            f"useful detections at this resolution. Classifier, seg, and depth graphs "
            f"are unaffected. FCOS convention is 640-800px input; export at "
            f">= 512 for detection.",
            stacklevel=2,
        )
    wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval()
    dummy_image = torch.randn(
        1, 3, backbone_resolution, backbone_resolution,
        device=self.device, dtype=torch.float32,
    )
    dummy_spatial = torch.randn(
        1, self.config.embed_dim, spatial_resolution, spatial_resolution,
        device=self.device, dtype=torch.float32,
    )
    backbone_path = os.path.join(out_dir, "argus_backbone.onnx")
    classifier_path = os.path.join(out_dir, "argus_classifier.onnx")
    seg_path = os.path.join(out_dir, "argus_seg_head.onnx")
    depth_path = os.path.join(out_dir, "argus_depth_head.onnx")
    detection_path = os.path.join(out_dir, "argus_detection_head.onnx")
    backbone_axes = None
    head_axes = None
    if dynamic_batch:
        backbone_axes = {
            "image": {0: "batch"},
            "cls_token": {0: "batch"},
            "spatial_features": {0: "batch"},
        }
        head_axes = {
            "spatial_features": {0: "batch"},
            "seg_logits": {0: "batch"},
            "depth_map": {0: "batch"},
        }
    # dynamo path crashes on EUPE's list-based forward; use legacy.
    with torch.inference_mode():
        torch.onnx.export(
            wrapper, dummy_image, backbone_path,
            input_names=["image"],
            output_names=["cls_token", "spatial_features"],
            dynamic_axes=backbone_axes,
            opset_version=opset_version,
            do_constant_folding=True,
            dynamo=False,
        )
    seg_wrapper = _SegHeadExportWrapper(self.seg_head, backbone_resolution).to(self.device).eval()
    torch.onnx.export(
        seg_wrapper, dummy_spatial, seg_path,
        input_names=["spatial_features"],
        output_names=["seg_logits"],
        dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None,
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    depth_wrapper = _DepthHeadExportWrapper(
        self.depth_head, spatial_resolution, spatial_resolution
    ).to(self.device).eval()
    # Patch tokens plus the N_PREFIX_TOKENS prefix tokens per hooked block.
    num_patch_tokens = spatial_resolution * spatial_resolution + N_PREFIX_TOKENS
    dummy_inter = tuple(
        torch.randn(1, num_patch_tokens, self.config.embed_dim,
                    device=self.device, dtype=torch.float32)
        for _ in range(len(HOOK_BLOCK_INDICES))
    )
    depth_input_names = [f"intermediate_{i}" for i in range(len(HOOK_BLOCK_INDICES))]
    if dynamic_batch:
        depth_axes = {name: {0: "batch"} for name in depth_input_names}
        depth_axes["depth_map"] = {0: "batch"}
    else:
        depth_axes = None
    torch.onnx.export(
        depth_wrapper, dummy_inter, depth_path,
        input_names=depth_input_names,
        output_names=["depth_map"],
        dynamic_axes=depth_axes,
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    classifier_wrapper = _ClassifierExportWrapper(
        self.class_logit_weight, self.class_logit_bias
    ).to(self.device).eval()
    dummy_cls = torch.randn(
        1, self.config.embed_dim, device=self.device, dtype=torch.float32,
    )
    if dynamic_batch:
        classifier_axes = {"cls_token": {0: "batch"}, "class_probs": {0: "batch"}}
    else:
        classifier_axes = None
    torch.onnx.export(
        classifier_wrapper, dummy_cls, classifier_path,
        input_names=["cls_token"],
        output_names=["class_probs"],
        dynamic_axes=classifier_axes,
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    detection_wrapper = _DetectionHeadExportWrapper(
        self.detection_head, backbone_resolution,
        include_nms=include_nms,
        nms_iou_threshold=nms_iou_threshold,
        nms_score_threshold=nms_score_threshold,
        nms_max_detections=nms_max_detections,
    ).to(self.device).eval()
    if include_nms:
        detection_output_names = ["boxes", "scores", "class_labels", "batch_indices"]
        # Post-NMS outputs are flat [M, ...]; no fixed batch axis to mark.
        # Spatial features input still has a dynamic batch dim so the graph
        # supports multi-image inference even with fused NMS.
        detection_axes = {"spatial_features": {0: "batch"}} if dynamic_batch else None
    else:
        detection_output_names = ["boxes", "scores"]
        if dynamic_batch:
            detection_axes = {
                "spatial_features": {0: "batch"},
                "boxes": {0: "batch"},
                "scores": {0: "batch"},
            }
        else:
            detection_axes = None
    torch.onnx.export(
        detection_wrapper, dummy_spatial, detection_path,
        input_names=["spatial_features"],
        output_names=detection_output_names,
        dynamic_axes=detection_axes,
        opset_version=opset_version,
        do_constant_folding=True,
        dynamo=False,
    )
    result = {
        "backbone": backbone_path,
        "classifier": classifier_path,
        "seg_head": seg_path,
        "depth_head": depth_path,
        "detection_head": detection_path,
    }
    if verify:
        try:
            import onnxruntime as ort
        except ImportError as e:
            raise ImportError("onnxruntime not installed; pip install onnxruntime") from e
        providers = ["CPUExecutionProvider"]
        # With dynamic_batch=False the graphs were traced at a fixed batch of
        # 1, and onnxruntime rejects inputs whose batch dim differs from the
        # declared shape. Only exercise batch 2 when the axis is dynamic.
        verify_batch = 2 if dynamic_batch else 1
        verify_image = torch.randn(verify_batch, 3, backbone_resolution, backbone_resolution, dtype=torch.float32)
        verify_spatial = torch.randn(verify_batch, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32)
        verify_cls = torch.randn(verify_batch, self.config.embed_dim, dtype=torch.float32)
        verify_inter = [
            torch.randn(verify_batch, num_patch_tokens, self.config.embed_dim, dtype=torch.float32)
            for _ in range(len(HOOK_BLOCK_INDICES))
        ]
        with torch.inference_mode():
            ref_cls, ref_spatial = wrapper(verify_image.to(self.device))
            ref_seg = seg_wrapper(verify_spatial.to(self.device))
            ref_depth = depth_wrapper(*[v.to(self.device) for v in verify_inter])
            ref_probs = classifier_wrapper(verify_cls.to(self.device))
            ref_det = detection_wrapper(verify_spatial.to(self.device))
        sess = ort.InferenceSession(backbone_path, providers=providers)
        ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()})
        cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max())
        spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max())
        sess = ort.InferenceSession(seg_path, providers=providers)
        ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0]
        seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max())
        sess = ort.InferenceSession(depth_path, providers=providers)
        ort_depth = sess.run(None, {f"intermediate_{i}": verify_inter[i].numpy()
                                    for i in range(len(HOOK_BLOCK_INDICES))})[0]
        depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max())
        sess = ort.InferenceSession(classifier_path, providers=providers)
        ort_probs = sess.run(None, {"cls_token": verify_cls.numpy()})[0]
        classifier_diff = float(np.abs(ort_probs - ref_probs.cpu().numpy()).max())
        sess = ort.InferenceSession(detection_path, providers=providers)
        ort_det = sess.run(None, {"spatial_features": verify_spatial.numpy()})
        verification = {
            "backbone_cls_max_diff": cls_diff,
            "backbone_spatial_max_diff": spatial_diff,
            "classifier_max_diff": classifier_diff,
            "seg_head_max_diff": seg_diff,
            "depth_head_max_diff": depth_diff,
            "verified_batch_size": verify_batch,
        }
        if include_nms:
            # NMS is inherently implementation-dependent. ONNX's
            # NonMaxSuppression and the torchvision eager fallback can break
            # ties differently when several detections share a score, or
            # when boxes sit right at the score cutoff. Element-wise
            # comparison of post-NMS outputs is therefore the wrong metric.
            # The only hard requirement here is that the graph runs. The
            # detection counts and the top detection's IoU, class match and
            # score diff are recorded for inspection, not enforced against
            # a threshold.
            pt_boxes, pt_scores, pt_labels, _ = ref_det
            ort_boxes, ort_scores, ort_labels, _ = ort_det
            pt_n = int(pt_scores.shape[0])
            ort_n = int(ort_scores.shape[0])
            verification["detection_nms_ref_count"] = pt_n
            verification["detection_nms_ort_count"] = ort_n
            if pt_n > 0 and ort_n > 0:
                pt_top = int(pt_scores.cpu().numpy().argmax())
                ort_top = int(ort_scores.argmax())
                pt_top_box = pt_boxes[pt_top].cpu().numpy()
                ort_top_box = ort_boxes[ort_top]
                # IoU of the two top boxes ([x1, y1, x2, y2]).
                x1 = max(pt_top_box[0], ort_top_box[0])
                y1 = max(pt_top_box[1], ort_top_box[1])
                x2 = min(pt_top_box[2], ort_top_box[2])
                y2 = min(pt_top_box[3], ort_top_box[3])
                inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
                pt_area = max(0.0, pt_top_box[2] - pt_top_box[0]) * max(0.0, pt_top_box[3] - pt_top_box[1])
                ort_area = max(0.0, ort_top_box[2] - ort_top_box[0]) * max(0.0, ort_top_box[3] - ort_top_box[1])
                union = max(1e-6, pt_area + ort_area - inter)
                verification["detection_nms_top_iou"] = float(inter / union)
                verification["detection_nms_top_class_match"] = bool(
                    int(pt_labels[pt_top].cpu()) == int(ort_labels[ort_top])
                )
                verification["detection_nms_top_score_diff"] = float(abs(
                    float(pt_scores[pt_top].cpu()) - float(ort_scores[ort_top])
                ))
            else:
                verification["detection_nms_top_iou"] = None
                verification["detection_nms_top_class_match"] = None
                verification["detection_nms_top_score_diff"] = None
        else:
            ort_boxes, ort_scores = ort_det
            ref_boxes, ref_scores = ref_det
            verification["detection_boxes_max_diff"] = float(
                np.abs(ort_boxes - ref_boxes.cpu().numpy()).max())
            verification["detection_scores_max_diff"] = float(
                np.abs(ort_scores - ref_scores.cpu().numpy()).max())
        # Tolerance is either a float applied uniformly, or a dict keyed by
        # verification output name (with an optional "default" key). When
        # only a float is supplied, detection boxes get a resolution-scaled
        # tolerance: exp() in the FCOS regression path amplifies FP
        # kernel-dispatch differences into pixel-scale absolute diffs.
        if isinstance(tolerance, dict):
            default_tol = float(tolerance.get("default", 5e-2))

            def _tol_for(key):
                return float(tolerance.get(key, default_tol))

            verification["tolerance"] = dict(tolerance)
        else:
            base = float(tolerance)
            box_tol = max(base, backbone_resolution * 5e-3)

            def _tol_for(key):
                return box_tol if key == "detection_boxes_max_diff" else base

            verification["tolerance"] = base
            verification["detection_boxes_tolerance"] = box_tol
        # Only "*_max_diff" metrics are enforced; NMS statistics are informational.
        for key, val in list(verification.items()):
            if not key.endswith("_max_diff"):
                continue
            t = _tol_for(key)
            if val > t:
                raise RuntimeError(
                    f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {t:.2e}"
                )
        result["verification"] = verification
    return result