argus.py: factor padding_mode through DPT blocks; add depth crop_border kwarg; batched correspond. README: align with shipped 3M cofiber detection head, drop FCOS framing, fix file sizes/param counts, add real IN1k val comparison, document qkv-bias choice. eval JSON: strip personal paths.
5eaa5db verified | """ | |
| Argus: multi-task perception on a single EUPE-ViT-B backbone. | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("phanerozoic/argus", trust_remote_code=True) | |
| result = model.perceive(image) | |
| The EUPE-ViT-B backbone architecture, all supporting layers, and the Argus | |
| task heads are inlined below. The backbone code is reproduced from | |
| facebookresearch/EUPE (Meta FAIR) under the FAIR Research License. | |
| """ | |
| import math | |
| import time | |
| from functools import partial | |
| from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torch.nn.init | |
| from PIL import Image | |
| from torch import Tensor, nn | |
| from torchvision.ops import nms | |
| from torchvision.transforms import v2 | |
| from transformers import PretrainedConfig, PreTrainedModel | |
| # =========================================================================== | |
| # EUPE backbone — vendored verbatim from facebookresearch/EUPE | |
| # =========================================================================== | |
| # ---------- utility helpers (from eupe/utils/utils.py) --------------------- | |
| def cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int]], List[int]]: | |
| shapes = [x.shape for x in x_list] | |
| num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list] | |
| flattened = torch.cat([x.flatten(0, -2) for x in x_list]) | |
| return flattened, shapes, num_tokens | |
| def uncat_with_shapes(flattened: Tensor, shapes: List[Tuple[int]], num_tokens: List[int]) -> List[Tensor]: | |
| outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0) | |
| shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes] | |
| outputs_reshaped = [o.reshape(shape) for o, shape in zip(outputs_splitted, shapes_adjusted)] | |
| return outputs_reshaped | |
| def named_apply( | |
| fn: Callable, | |
| module: nn.Module, | |
| name: str = "", | |
| depth_first: bool = True, | |
| include_root: bool = False, | |
| ) -> nn.Module: | |
| if not depth_first and include_root: | |
| fn(module=module, name=name) | |
| for child_name, child_module in module.named_children(): | |
| child_name = ".".join((name, child_name)) if name else child_name | |
| named_apply( | |
| fn=fn, | |
| module=child_module, | |
| name=child_name, | |
| depth_first=depth_first, | |
| include_root=True, | |
| ) | |
| if depth_first and include_root: | |
| fn(module=module, name=name) | |
| return module | |
| # ---------- RMSNorm (from eupe/layers/rms_norm.py) ------------------------- | |
| class RMSNorm(nn.Module): | |
| def __init__(self, dim: int, eps: float = 1e-5): | |
| super().__init__() | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| self.eps = eps | |
| def reset_parameters(self) -> None: | |
| nn.init.constant_(self.weight, 1) | |
| def _norm(self, x: Tensor) -> Tensor: | |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| def forward(self, x: Tensor) -> Tensor: | |
| output = self._norm(x.float()).type_as(x) | |
| return output * self.weight | |
| # ---------- LayerScale (from eupe/layers/layer_scale.py) ------------------- | |
| class LayerScale(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| init_values: Union[float, Tensor] = 1e-5, | |
| inplace: bool = False, | |
| device=None, | |
| ) -> None: | |
| super().__init__() | |
| self.inplace = inplace | |
| self.gamma = nn.Parameter(torch.empty(dim, device=device)) | |
| self.init_values = init_values | |
| def reset_parameters(self): | |
| nn.init.constant_(self.gamma, self.init_values) | |
| def forward(self, x: Tensor) -> Tensor: | |
| return x.mul_(self.gamma) if self.inplace else x * self.gamma | |
| # ---------- PatchEmbed (from eupe/layers/patch_embed.py) ------------------- | |
| def make_2tuple(x): | |
| if isinstance(x, tuple): | |
| assert len(x) == 2 | |
| return x | |
| assert isinstance(x, int) | |
| return (x, x) | |
| class PatchEmbed(nn.Module): | |
| def __init__( | |
| self, | |
| img_size: Union[int, Tuple[int, int]] = 224, | |
| patch_size: Union[int, Tuple[int, int]] = 16, | |
| in_chans: int = 3, | |
| embed_dim: int = 768, | |
| norm_layer: Optional[Callable] = None, | |
| flatten_embedding: bool = True, | |
| ) -> None: | |
| super().__init__() | |
| image_HW = make_2tuple(img_size) | |
| patch_HW = make_2tuple(patch_size) | |
| patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1]) | |
| self.img_size = image_HW | |
| self.patch_size = patch_HW | |
| self.patches_resolution = patch_grid_size | |
| self.num_patches = patch_grid_size[0] * patch_grid_size[1] | |
| self.in_chans = in_chans | |
| self.embed_dim = embed_dim | |
| self.flatten_embedding = flatten_embedding | |
| self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) | |
| self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() | |
| def forward(self, x: Tensor) -> Tensor: | |
| _, _, H, W = x.shape | |
| x = self.proj(x) | |
| H, W = x.size(2), x.size(3) | |
| x = x.flatten(2).transpose(1, 2) | |
| x = self.norm(x) | |
| if not self.flatten_embedding: | |
| x = x.reshape(-1, H, W, self.embed_dim) | |
| return x | |
| def reset_parameters(self): | |
| k = 1 / (self.in_chans * (self.patch_size[0] ** 2)) | |
| nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k)) | |
| if self.proj.bias is not None: | |
| nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k)) | |
| # ---------- RoPE (from eupe/layers/rope_position_encoding.py) -------------- | |
| class RopePositionEmbedding(nn.Module): | |
| def __init__( | |
| self, | |
| embed_dim: int, | |
| *, | |
| num_heads: int, | |
| base: Optional[float] = 100.0, | |
| min_period: Optional[float] = None, | |
| max_period: Optional[float] = None, | |
| normalize_coords: Literal["min", "max", "separate"] = "separate", | |
| shift_coords: Optional[float] = None, | |
| jitter_coords: Optional[float] = None, | |
| rescale_coords: Optional[float] = None, | |
| dtype: Optional[torch.dtype] = None, | |
| device: Optional[torch.device] = None, | |
| ): | |
| super().__init__() | |
| assert embed_dim % (4 * num_heads) == 0 | |
| both_periods = min_period is not None and max_period is not None | |
| if (base is None and not both_periods) or (base is not None and both_periods): | |
| raise ValueError("Either `base` or `min_period`+`max_period` must be provided.") | |
| D_head = embed_dim // num_heads | |
| self.base = base | |
| self.min_period = min_period | |
| self.max_period = max_period | |
| self.D_head = D_head | |
| self.normalize_coords = normalize_coords | |
| self.shift_coords = shift_coords | |
| self.jitter_coords = jitter_coords | |
| self.rescale_coords = rescale_coords | |
| self.dtype = dtype | |
| self.register_buffer( | |
| "periods", | |
| torch.empty(D_head // 4, device=device, dtype=dtype), | |
| persistent=True, | |
| ) | |
| self._init_weights() | |
| def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]: | |
| device = self.periods.device | |
| dtype = self.dtype | |
| dd = {"device": device, "dtype": dtype} | |
| if self.normalize_coords == "max": | |
| max_HW = max(H, W) | |
| coords_h = torch.arange(0.5, H, **dd) / max_HW | |
| coords_w = torch.arange(0.5, W, **dd) / max_HW | |
| elif self.normalize_coords == "min": | |
| min_HW = min(H, W) | |
| coords_h = torch.arange(0.5, H, **dd) / min_HW | |
| coords_w = torch.arange(0.5, W, **dd) / min_HW | |
| elif self.normalize_coords == "separate": | |
| coords_h = torch.arange(0.5, H, **dd) / H | |
| coords_w = torch.arange(0.5, W, **dd) / W | |
| else: | |
| raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") | |
| coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) | |
| coords = coords.flatten(0, 1) | |
| coords = 2.0 * coords - 1.0 | |
| if self.training and self.shift_coords is not None: | |
| shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords) | |
| coords += shift_hw[None, :] | |
| if self.training and self.jitter_coords is not None: | |
| jitter_max = np.log(self.jitter_coords) | |
| jitter_min = -jitter_max | |
| jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() | |
| coords *= jitter_hw[None, :] | |
| if self.training and self.rescale_coords is not None: | |
| rescale_max = np.log(self.rescale_coords) | |
| rescale_min = -rescale_max | |
| rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() | |
| coords *= rescale_hw | |
| angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] | |
| angles = angles.flatten(1, 2) | |
| angles = angles.tile(2) | |
| cos = torch.cos(angles) | |
| sin = torch.sin(angles) | |
| return (sin, cos) | |
| def _init_weights(self): | |
| device = self.periods.device | |
| dtype = self.dtype | |
| if self.base is not None: | |
| periods = self.base ** ( | |
| 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) | |
| ) | |
| else: | |
| base = self.max_period / self.min_period | |
| exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype) | |
| periods = base ** exponents | |
| periods = periods / base | |
| periods = periods * self.max_period | |
| self.periods.data = periods | |
| # ---------- FFN layers (from eupe/layers/ffn_layers.py) -------------------- | |
| class ListForwardMixin(object): | |
| def forward(self, x: Tensor): | |
| raise NotImplementedError | |
| def forward_list(self, x_list: List[Tensor]) -> List[Tensor]: | |
| x_flat, shapes, num_tokens = cat_keep_shapes(x_list) | |
| x_flat = self.forward(x_flat) | |
| return uncat_with_shapes(x_flat, shapes, num_tokens) | |
| class Mlp(nn.Module, ListForwardMixin): | |
| def __init__( | |
| self, | |
| in_features: int, | |
| hidden_features: Optional[int] = None, | |
| out_features: Optional[int] = None, | |
| act_layer: Callable[..., nn.Module] = nn.GELU, | |
| drop: float = 0.0, | |
| bias: bool = True, | |
| device=None, | |
| ) -> None: | |
| super().__init__() | |
| out_features = out_features or in_features | |
| hidden_features = hidden_features or in_features | |
| self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device) | |
| self.act = act_layer() | |
| self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device) | |
| self.drop = nn.Dropout(drop) | |
| def forward(self, x: Tensor) -> Tensor: | |
| x = self.fc1(x) | |
| x = self.act(x) | |
| x = self.drop(x) | |
| x = self.fc2(x) | |
| x = self.drop(x) | |
| return x | |
| class SwiGLUFFN(nn.Module, ListForwardMixin): | |
| def __init__( | |
| self, | |
| in_features: int, | |
| hidden_features: Optional[int] = None, | |
| out_features: Optional[int] = None, | |
| act_layer: Optional[Callable[..., nn.Module]] = None, | |
| drop: float = 0.0, | |
| bias: bool = True, | |
| align_to: int = 8, | |
| device=None, | |
| ) -> None: | |
| super().__init__() | |
| out_features = out_features or in_features | |
| hidden_features = hidden_features or in_features | |
| d = int(hidden_features * 2 / 3) | |
| swiglu_hidden_features = d + (-d % align_to) | |
| self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) | |
| self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device) | |
| self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device) | |
| def forward(self, x: Tensor) -> Tensor: | |
| x1 = self.w1(x) | |
| x2 = self.w2(x) | |
| hidden = F.silu(x1) * x2 | |
| return self.w3(hidden) | |
| # ---------- Attention (from eupe/layers/attention.py) ---------------------- | |
| def rope_rotate_half(x: Tensor) -> Tensor: | |
| x1, x2 = x.chunk(2, dim=-1) | |
| return torch.cat([-x2, x1], dim=-1) | |
| def rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: | |
| return (x * cos) + (rope_rotate_half(x) * sin) | |
| class LinearKMaskedBias(nn.Linear): | |
| def __init__(self, *args, **kwargs): | |
| super().__init__(*args, **kwargs) | |
| o = self.out_features | |
| assert o % 3 == 0 | |
| if self.bias is not None: | |
| self.register_buffer("bias_mask", torch.full_like(self.bias, fill_value=math.nan)) | |
| def forward(self, input: Tensor) -> Tensor: | |
| masked_bias = self.bias * self.bias_mask.to(self.bias.dtype) if self.bias is not None else None | |
| return F.linear(input, self.weight, masked_bias) | |
| class SelfAttention(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int = 8, | |
| qkv_bias: bool = False, | |
| proj_bias: bool = True, | |
| attn_drop: float = 0.0, | |
| proj_drop: float = 0.0, | |
| mask_k_bias: bool = False, | |
| device=None, | |
| ) -> None: | |
| super().__init__() | |
| self.num_heads = num_heads | |
| head_dim = dim // num_heads | |
| self.scale = head_dim ** -0.5 | |
| linear_class = LinearKMaskedBias if mask_k_bias else nn.Linear | |
| self.qkv = linear_class(dim, dim * 3, bias=qkv_bias, device=device) | |
| self.attn_drop = nn.Dropout(attn_drop) | |
| self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device) | |
| self.proj_drop = nn.Dropout(proj_drop) | |
| def apply_rope(self, q: Tensor, k: Tensor, rope) -> Tuple[Tensor, Tensor]: | |
| q_dtype = q.dtype | |
| k_dtype = k.dtype | |
| sin, cos = rope | |
| rope_dtype = sin.dtype | |
| q = q.to(dtype=rope_dtype) | |
| k = k.to(dtype=rope_dtype) | |
| N = q.shape[-2] | |
| prefix = N - sin.shape[-2] | |
| assert prefix >= 0 | |
| q_prefix = q[:, :, :prefix, :] | |
| q = rope_apply(q[:, :, prefix:, :], sin, cos) | |
| q = torch.cat((q_prefix, q), dim=-2) | |
| k_prefix = k[:, :, :prefix, :] | |
| k = rope_apply(k[:, :, prefix:, :], sin, cos) | |
| k = torch.cat((k_prefix, k), dim=-2) | |
| q = q.to(dtype=q_dtype) | |
| k = k.to(dtype=k_dtype) | |
| return q, k | |
| def forward(self, x: Tensor, attn_bias=None, rope=None) -> Tensor: | |
| qkv = self.qkv(x) | |
| attn_v = self.compute_attention(qkv=qkv, attn_bias=attn_bias, rope=rope) | |
| x = self.proj(attn_v) | |
| x = self.proj_drop(x) | |
| return x | |
| def forward_list(self, x_list, attn_bias=None, rope_list=None) -> List[Tensor]: | |
| assert len(x_list) == len(rope_list) | |
| x_flat, shapes, num_tokens = cat_keep_shapes(x_list) | |
| qkv_flat = self.qkv(x_flat) | |
| qkv_list = uncat_with_shapes(qkv_flat, shapes, num_tokens) | |
| att_out = [] | |
| for _, (qkv, _, rope) in enumerate(zip(qkv_list, shapes, rope_list)): | |
| att_out.append(self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)) | |
| x_flat, shapes, num_tokens = cat_keep_shapes(att_out) | |
| x_flat = self.proj(x_flat) | |
| return uncat_with_shapes(x_flat, shapes, num_tokens) | |
| def compute_attention(self, qkv: Tensor, attn_bias=None, rope=None) -> Tensor: | |
| assert attn_bias is None | |
| B, N, _ = qkv.shape | |
| C = self.qkv.in_features | |
| qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads) | |
| q, k, v = torch.unbind(qkv, 2) | |
| q, k, v = [t.transpose(1, 2) for t in [q, k, v]] | |
| if rope is not None: | |
| q, k = self.apply_rope(q, k, rope) | |
| x = torch.nn.functional.scaled_dot_product_attention(q, k, v) | |
| x = x.transpose(1, 2) | |
| return x.reshape([B, N, C]) | |
| # ---------- Block (from eupe/layers/block.py) ------------------------------ | |
| class SelfAttentionBlock(nn.Module): | |
| def __init__( | |
| self, | |
| dim: int, | |
| num_heads: int, | |
| ffn_ratio: float = 4.0, | |
| qkv_bias: bool = False, | |
| proj_bias: bool = True, | |
| ffn_bias: bool = True, | |
| drop: float = 0.0, | |
| attn_drop: float = 0.0, | |
| init_values=None, | |
| drop_path: float = 0.0, | |
| act_layer: Callable[..., nn.Module] = nn.GELU, | |
| norm_layer: Callable[..., nn.Module] = nn.LayerNorm, | |
| attn_class: Callable[..., nn.Module] = SelfAttention, | |
| ffn_layer: Callable[..., nn.Module] = Mlp, | |
| mask_k_bias: bool = False, | |
| device=None, | |
| ) -> None: | |
| super().__init__() | |
| self.norm1 = norm_layer(dim) | |
| self.attn = attn_class( | |
| dim, | |
| num_heads=num_heads, | |
| qkv_bias=qkv_bias, | |
| proj_bias=proj_bias, | |
| attn_drop=attn_drop, | |
| proj_drop=drop, | |
| mask_k_bias=mask_k_bias, | |
| device=device, | |
| ) | |
| self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() | |
| self.norm2 = norm_layer(dim) | |
| mlp_hidden_dim = int(dim * ffn_ratio) | |
| self.mlp = ffn_layer( | |
| in_features=dim, | |
| hidden_features=mlp_hidden_dim, | |
| act_layer=act_layer, | |
| drop=drop, | |
| bias=ffn_bias, | |
| device=device, | |
| ) | |
| self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity() | |
| self.sample_drop_ratio = drop_path | |
| def _maybe_index_rope(rope, indices: Tensor): | |
| if rope is None: | |
| return None | |
| sin, cos = rope | |
| assert sin.ndim == cos.ndim | |
| if sin.ndim == 4: | |
| return sin[indices], cos[indices] | |
| return sin, cos | |
| def _forward_list(self, x_list: List[Tensor], rope_list=None) -> List[Tensor]: | |
| b_list = [x.shape[0] for x in x_list] | |
| sample_subset_sizes = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list] | |
| if self.training and self.sample_drop_ratio > 0.0: | |
| residual_scale_factors = [b / s for b, s in zip(b_list, sample_subset_sizes)] | |
| indices_1_list = [ | |
| torch.randperm(b, device=x.device)[:s] | |
| for x, b, s in zip(x_list, b_list, sample_subset_sizes) | |
| ] | |
| x_subset_1_list = [x[i] for x, i in zip(x_list, indices_1_list)] | |
| if rope_list is not None: | |
| rope_subset_list = [ | |
| self._maybe_index_rope(r, i) for r, i in zip(rope_list, indices_1_list) | |
| ] | |
| else: | |
| rope_subset_list = rope_list | |
| flattened, shapes, num_tokens = cat_keep_shapes(x_subset_1_list) | |
| norm1 = uncat_with_shapes(self.norm1(flattened), shapes, num_tokens) | |
| residual_1_list = self.attn.forward_list(norm1, rope_list=rope_subset_list) | |
| x_attn_list = [ | |
| torch.index_add(x, dim=0, source=self.ls1(r1), index=i1, alpha=rsf) | |
| for x, r1, i1, rsf in zip(x_list, residual_1_list, indices_1_list, residual_scale_factors) | |
| ] | |
| indices_2_list = [ | |
| torch.randperm(b, device=x.device)[:s] | |
| for x, b, s in zip(x_list, b_list, sample_subset_sizes) | |
| ] | |
| x_subset_2_list = [x[i] for x, i in zip(x_attn_list, indices_2_list)] | |
| flattened, shapes, num_tokens = cat_keep_shapes(x_subset_2_list) | |
| norm2_list = uncat_with_shapes(self.norm2(flattened), shapes, num_tokens) | |
| residual_2_list = self.mlp.forward_list(norm2_list) | |
| x_ffn = [ | |
| torch.index_add(xa, dim=0, source=self.ls2(r2), index=i2, alpha=rsf) | |
| for xa, r2, i2, rsf in zip(x_attn_list, residual_2_list, indices_2_list, residual_scale_factors) | |
| ] | |
| else: | |
| x_out = [] | |
| for x, rope in zip(x_list, rope_list): | |
| x_attn = x + self.ls1(self.attn(self.norm1(x), rope=rope)) | |
| x_ffn = x_attn + self.ls2(self.mlp(self.norm2(x_attn))) | |
| x_out.append(x_ffn) | |
| x_ffn = x_out | |
| return x_ffn | |
| def forward(self, x_or_x_list, rope_or_rope_list=None) -> List[Tensor]: | |
| if isinstance(x_or_x_list, Tensor): | |
| return self._forward_list([x_or_x_list], rope_list=[rope_or_rope_list])[0] | |
| elif isinstance(x_or_x_list, list): | |
| if rope_or_rope_list is None: | |
| rope_or_rope_list = [None for _ in x_or_x_list] | |
| return self._forward_list(x_or_x_list, rope_list=rope_or_rope_list) | |
| raise AssertionError | |
| # ---------- DinoVisionTransformer (from eupe/models/vision_transformer.py) | |
| ffn_layer_dict = { | |
| "mlp": Mlp, | |
| "swiglu": SwiGLUFFN, | |
| "swiglu32": partial(SwiGLUFFN, align_to=32), | |
| "swiglu64": partial(SwiGLUFFN, align_to=64), | |
| "swiglu128": partial(SwiGLUFFN, align_to=128), | |
| } | |
| norm_layer_dict = { | |
| "layernorm": partial(nn.LayerNorm, eps=1e-6), | |
| "layernormbf16": partial(nn.LayerNorm, eps=1e-5), | |
| "rmsnorm": RMSNorm, | |
| } | |
| dtype_dict = { | |
| "fp32": torch.float32, | |
| "fp16": torch.float16, | |
| "bf16": torch.bfloat16, | |
| } | |
| def init_weights_vit(module: nn.Module, name: str = ""): | |
| if isinstance(module, nn.Linear): | |
| torch.nn.init.trunc_normal_(module.weight, std=0.02) | |
| if module.bias is not None: | |
| nn.init.zeros_(module.bias) | |
| if hasattr(module, "bias_mask") and module.bias_mask is not None: | |
| o = module.out_features | |
| module.bias_mask.fill_(1) | |
| module.bias_mask[o // 3 : 2 * o // 3].fill_(0) | |
| if isinstance(module, nn.LayerNorm): | |
| module.reset_parameters() | |
| if isinstance(module, LayerScale): | |
| module.reset_parameters() | |
| if isinstance(module, PatchEmbed): | |
| module.reset_parameters() | |
| if isinstance(module, RMSNorm): | |
| module.reset_parameters() | |
| class DinoVisionTransformer(nn.Module): | |
| def __init__( | |
| self, | |
| *, | |
| img_size: int = 224, | |
| patch_size: int = 16, | |
| in_chans: int = 3, | |
| pos_embed_rope_base: float = 100.0, | |
| pos_embed_rope_min_period: Optional[float] = None, | |
| pos_embed_rope_max_period: Optional[float] = None, | |
| pos_embed_rope_normalize_coords: Literal["min", "max", "separate"] = "separate", | |
| pos_embed_rope_shift_coords: Optional[float] = None, | |
| pos_embed_rope_jitter_coords: Optional[float] = None, | |
| pos_embed_rope_rescale_coords: Optional[float] = None, | |
| pos_embed_rope_dtype: str = "bf16", | |
| embed_dim: int = 768, | |
| depth: int = 12, | |
| num_heads: int = 12, | |
| ffn_ratio: float = 4.0, | |
| qkv_bias: bool = True, | |
| drop_path_rate: float = 0.0, | |
| layerscale_init: Optional[float] = None, | |
| norm_layer: str = "layernorm", | |
| ffn_layer: str = "mlp", | |
| ffn_bias: bool = True, | |
| proj_bias: bool = True, | |
| n_storage_tokens: int = 0, | |
| mask_k_bias: bool = False, | |
| untie_cls_and_patch_norms: bool = False, | |
| untie_global_and_local_cls_norm: bool = False, | |
| device: Any = None, | |
| **ignored_kwargs, | |
| ): | |
| super().__init__() | |
| del ignored_kwargs | |
| norm_layer_cls = norm_layer_dict[norm_layer] | |
| self.num_features = self.embed_dim = embed_dim | |
| self.n_blocks = depth | |
| self.num_heads = num_heads | |
| self.patch_size = patch_size | |
| self.patch_embed = PatchEmbed( | |
| img_size=img_size, | |
| patch_size=patch_size, | |
| in_chans=in_chans, | |
| embed_dim=embed_dim, | |
| flatten_embedding=False, | |
| ) | |
| self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device)) | |
| self.n_storage_tokens = n_storage_tokens | |
| if self.n_storage_tokens > 0: | |
| self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device)) | |
| self.rope_embed = RopePositionEmbedding( | |
| embed_dim=embed_dim, | |
| num_heads=num_heads, | |
| base=pos_embed_rope_base, | |
| min_period=pos_embed_rope_min_period, | |
| max_period=pos_embed_rope_max_period, | |
| normalize_coords=pos_embed_rope_normalize_coords, | |
| shift_coords=pos_embed_rope_shift_coords, | |
| jitter_coords=pos_embed_rope_jitter_coords, | |
| rescale_coords=pos_embed_rope_rescale_coords, | |
| dtype=dtype_dict[pos_embed_rope_dtype], | |
| device=device, | |
| ) | |
| ffn_layer_cls = ffn_layer_dict[ffn_layer] | |
| ffn_ratio_sequence = [ffn_ratio] * depth | |
| blocks_list = [ | |
| SelfAttentionBlock( | |
| dim=embed_dim, | |
| num_heads=num_heads, | |
| ffn_ratio=ffn_ratio_sequence[i], | |
| qkv_bias=qkv_bias, | |
| proj_bias=proj_bias, | |
| ffn_bias=ffn_bias, | |
| drop_path=drop_path_rate, | |
| norm_layer=norm_layer_cls, | |
| act_layer=nn.GELU, | |
| ffn_layer=ffn_layer_cls, | |
| init_values=layerscale_init, | |
| mask_k_bias=mask_k_bias, | |
| device=device, | |
| ) | |
| for i in range(depth) | |
| ] | |
| self.chunked_blocks = False | |
| self.blocks = nn.ModuleList(blocks_list) | |
| self.norm = norm_layer_cls(embed_dim) | |
| self.untie_cls_and_patch_norms = untie_cls_and_patch_norms | |
| self.cls_norm = norm_layer_cls(embed_dim) if untie_cls_and_patch_norms else None | |
| self.untie_global_and_local_cls_norm = untie_global_and_local_cls_norm | |
| self.local_cls_norm = norm_layer_cls(embed_dim) if untie_global_and_local_cls_norm else None | |
| self.head = nn.Identity() | |
| self.mask_token = nn.Parameter(torch.empty(1, embed_dim, device=device)) | |
| def init_weights(self): | |
| self.rope_embed._init_weights() | |
| nn.init.normal_(self.cls_token, std=0.02) | |
| if self.n_storage_tokens > 0: | |
| nn.init.normal_(self.storage_tokens, std=0.02) | |
| nn.init.zeros_(self.mask_token) | |
| named_apply(init_weights_vit, self) | |
| def prepare_tokens_with_masks(self, x: Tensor, masks=None) -> Tuple[Tensor, Tuple[int, int]]: | |
| x = self.patch_embed(x) | |
| B, H, W, _ = x.shape | |
| x = x.flatten(1, 2) | |
| if masks is not None: | |
| x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) | |
| cls_token = self.cls_token | |
| else: | |
| cls_token = self.cls_token + 0 * self.mask_token | |
| if self.n_storage_tokens > 0: | |
| storage_tokens = self.storage_tokens | |
| else: | |
| storage_tokens = torch.empty( | |
| 1, 0, cls_token.shape[-1], | |
| dtype=cls_token.dtype, device=cls_token.device, | |
| ) | |
| x = torch.cat( | |
| [cls_token.expand(B, -1, -1), storage_tokens.expand(B, -1, -1), x], | |
| dim=1, | |
| ) | |
| return x, (H, W) | |
| def forward_features_list(self, x_list: List[Tensor], masks_list: List[Tensor]) -> List[Dict[str, Tensor]]: | |
| x = [] | |
| rope = [] | |
| for t_x, t_masks in zip(x_list, masks_list): | |
| t2_x, hw_tuple = self.prepare_tokens_with_masks(t_x, t_masks) | |
| x.append(t2_x) | |
| rope.append(hw_tuple) | |
| for blk in self.blocks: | |
| if self.rope_embed is not None: | |
| rope_sincos = [self.rope_embed(H=H, W=W) for H, W in rope] | |
| else: | |
| rope_sincos = [None for _ in rope] | |
| x = blk(x, rope_sincos) | |
| all_x = x | |
| output = [] | |
| for idx, (x, masks) in enumerate(zip(all_x, masks_list)): | |
| if self.untie_cls_and_patch_norms or self.untie_global_and_local_cls_norm: | |
| if self.untie_global_and_local_cls_norm and self.training and idx == 1: | |
| x_norm_cls_reg = self.local_cls_norm(x[:, : self.n_storage_tokens + 1]) | |
| elif self.untie_cls_and_patch_norms: | |
| x_norm_cls_reg = self.cls_norm(x[:, : self.n_storage_tokens + 1]) | |
| else: | |
| x_norm_cls_reg = self.norm(x[:, : self.n_storage_tokens + 1]) | |
| x_norm_patch = self.norm(x[:, self.n_storage_tokens + 1 :]) | |
| else: | |
| x_norm = self.norm(x) | |
| x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1] | |
| x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :] | |
| output.append({ | |
| "x_norm_clstoken": x_norm_cls_reg[:, 0], | |
| "x_storage_tokens": x_norm_cls_reg[:, 1:], | |
| "x_norm_patchtokens": x_norm_patch, | |
| "x_prenorm": x, | |
| "masks": masks, | |
| }) | |
| return output | |
| def forward_features(self, x, masks: Optional[Tensor] = None): | |
| if isinstance(x, torch.Tensor): | |
| return self.forward_features_list([x], [masks])[0] | |
| return self.forward_features_list(x, masks) | |
| def forward(self, *args, is_training: bool = False, **kwargs): | |
| ret = self.forward_features(*args, **kwargs) | |
| if is_training: | |
| return ret | |
| return self.head(ret["x_norm_clstoken"]) | |
| def build_eupe_vitb16() -> DinoVisionTransformer: | |
| # qkv_bias=False, mask_k_bias=False: the upstream EUPE-ViT-B release shipped | |
| # with `qkv.bias_mask` filled with zeros, which makes the effective qkv bias | |
| # zero at every block (masked_bias = bias * 0 = 0). We drop the bias parameter | |
| # entirely here — the computation is bitwise-equivalent in fp32, bf16 output | |
| # drift is sub-ULP and absorbed by every head except DPT depth (where it | |
| # appears as ~2cm noise against a 39cm RMSE, i.e. below the head's own floor). | |
| return DinoVisionTransformer( | |
| img_size=224, | |
| patch_size=16, | |
| in_chans=3, | |
| pos_embed_rope_base=100, | |
| pos_embed_rope_normalize_coords="separate", | |
| pos_embed_rope_rescale_coords=2, | |
| pos_embed_rope_dtype="fp32", | |
| embed_dim=768, | |
| depth=12, | |
| num_heads=12, | |
| ffn_ratio=4, | |
| qkv_bias=False, | |
| drop_path_rate=0.0, | |
| layerscale_init=1.0e-05, | |
| norm_layer="layernormbf16", | |
| ffn_layer="mlp", | |
| ffn_bias=True, | |
| proj_bias=True, | |
| n_storage_tokens=4, | |
| mask_k_bias=False, | |
| ) | |
| # =========================================================================== | |
| # Argus task heads | |
| # =========================================================================== | |
| def make_eupe_transform(resize_size: int): | |
| return v2.Compose([ | |
| v2.ToImage(), | |
| v2.Resize((resize_size, resize_size), antialias=True), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | |
| ]) | |
| def _normalize_image_input(image_or_images) -> Tuple[bool, list]: | |
| """Returns (was_single, [images]). Accepts a PIL.Image or an iterable of them.""" | |
| if isinstance(image_or_images, Image.Image): | |
| return True, [image_or_images] | |
| images = list(image_or_images) | |
| if not images: | |
| raise ValueError("empty image list") | |
| for i, img in enumerate(images): | |
| if not isinstance(img, Image.Image): | |
| raise TypeError(f"images[{i}] is {type(img).__name__}, expected PIL.Image") | |
| return False, images | |
| class _BackboneExportWrapper(nn.Module): | |
| """ONNX-friendly wrapper: returns (cls, spatial) instead of a dict.""" | |
| def __init__(self, backbone: nn.Module): | |
| super().__init__() | |
| self.backbone = backbone | |
| def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: | |
| out = self.backbone.forward_features(x) | |
| cls = out["x_norm_clstoken"] | |
| patches = out["x_norm_patchtokens"] | |
| B, N, D = patches.shape | |
| h = w = int(N ** 0.5) | |
| spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) | |
| return cls, spatial | |
| class _SegHeadExportWrapper(nn.Module): | |
| """ONNX-friendly wrapper: seg head + bilinear upsample to input resolution. | |
| The bare seg head emits stride-16 logits (e.g. [B, 150, 40, 40] at 640px | |
| input). model.segment() upsamples those to the input resolution before | |
| argmax. This wrapper folds the upsample into the graph so the ONNX seg | |
| output is already at input resolution — consumers argmax directly without | |
| a separate interpolation step. | |
| """ | |
| def __init__(self, seg_head: nn.Module, resolution: int): | |
| super().__init__() | |
| self.seg_head = seg_head | |
| self.resolution = resolution | |
| def forward(self, spatial_features: Tensor) -> Tensor: | |
| logits = self.seg_head(spatial_features) | |
| return F.interpolate(logits, size=(self.resolution, self.resolution), | |
| mode="bilinear", align_corners=False) | |
| class _DepthHeadExportWrapper(nn.Module): | |
| """ONNX-friendly wrapper for the DPT depth head. | |
| DPTDepthDecoder.forward takes (intermediates: List[Tensor], H: int, W: int), | |
| which torch.onnx.export cannot trace cleanly because the List contains four | |
| tensors and H/W are Python ints. The wrapper accepts the four intermediate | |
| ViT-block activations as separate positional tensor inputs and forwards them | |
| to the underlying decoder with the captured H and W. | |
| """ | |
| def __init__(self, depth_head: nn.Module, H: int, W: int): | |
| super().__init__() | |
| self.depth_head = depth_head | |
| self.H = H | |
| self.W = W | |
| def forward(self, inter0: Tensor, inter1: Tensor, inter2: Tensor, inter3: Tensor) -> Tensor: | |
| return self.depth_head([inter0, inter1, inter2, inter3], self.H, self.W) | |
| class _ClassifierExportWrapper(nn.Module): | |
| """ONNX-friendly wrapper for the ImageNet linear-softmax classifier. | |
| Takes the backbone's CLS token, L2-normalizes, applies the stored | |
| Linear(embed_dim, 1000) weight + bias, and returns a softmax | |
| distribution over the 1000 ImageNet classes. The weight and bias are | |
| captured as buffers so the graph is self-contained — no separate | |
| weight file needed for classification inference. | |
| """ | |
| def __init__(self, class_weight: Tensor, class_bias: Tensor): | |
| super().__init__() | |
| self.register_buffer("weight", class_weight.float().clone()) | |
| self.register_buffer("bias", class_bias.float().clone()) | |
| def forward(self, cls_token: Tensor) -> Tensor: | |
| x = F.normalize(cls_token, dim=-1) | |
| logits = F.linear(x, self.weight, self.bias) | |
| return F.softmax(logits, dim=-1) | |
| class _ONNXBatchedNMS(torch.autograd.Function): | |
| """Autograd wrapper that exports to ONNX NonMaxSuppression (opset >= 10). | |
| ONNX's NonMaxSuppression handles batched multi-class NMS natively: | |
| boxes [B, N, 4] in [y1, x1, y2, x2] order (center_point_box=0) | |
| scores [B, C, N] | |
| -> selected_indices [M, 3] where each row is [batch, class, box] | |
| The eager forward path reproduces this via torchvision.ops.nms so | |
| PyTorch tracing and verify=True both work without calling into | |
| ORT for the reference. | |
| """ | |
| def symbolic(g, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): | |
| return g.op( | |
| "NonMaxSuppression", | |
| boxes, scores, | |
| max_output_boxes_per_class, | |
| iou_threshold, | |
| score_threshold, | |
| center_point_box_i=0, | |
| ) | |
| def forward(ctx, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold): | |
| from torchvision.ops import nms as tv_nms | |
| B, N, _ = boxes.shape | |
| _, C, _ = scores.shape | |
| max_out = int(max_output_boxes_per_class.item()) | |
| iou_thr = float(iou_threshold.item()) | |
| score_thr = float(score_threshold.item()) | |
| results: List[List[int]] = [] | |
| for b in range(B): | |
| for c in range(C): | |
| sc = scores[b, c] | |
| mask = sc > score_thr | |
| if not mask.any(): | |
| continue | |
| idx = mask.nonzero(as_tuple=True)[0] | |
| # tv_nms expects [x1, y1, x2, y2]; our boxes are [y1, x1, y2, x2]. | |
| bx_xyxy = boxes[b, idx][:, [1, 0, 3, 2]] | |
| keep = tv_nms(bx_xyxy, sc[idx], iou_thr)[:max_out] | |
| for k in keep.tolist(): | |
| results.append([b, c, int(idx[k].item())]) | |
| if not results: | |
| return torch.zeros((0, 3), dtype=torch.long, device=boxes.device) | |
| return torch.tensor(results, dtype=torch.long, device=boxes.device) | |
| class _DetectionHeadExportWrapper(nn.Module): | |
| """ONNX-friendly wrapper for the detection head (simple FPN + FCOS). | |
| Takes backbone stride-16 spatial features and returns decoded | |
| per-location predictions concatenated across all five FPN levels. | |
| Without NMS (default): | |
| - boxes [B, N_total, 4] xyxy in input-resolution pixels, | |
| decoded as (location - exp(reg)) / | |
| (location + exp(reg)) and clamped. | |
| - scores [B, N_total, num_classes] | |
| sigmoid(cls_logits) * sigmoid(centerness). | |
| With NMS (include_nms=True): | |
| - boxes [M, 4] xyxy in input-resolution pixels | |
| - scores [M] | |
| - class_labels [M] int64 class index | |
| - batch_indices[M] int64 batch index | |
| N_total = sum(H_i * W_i) across strides [8, 16, 32, 64, 128]. At | |
| 640px input: 6400 + 1600 + 400 + 100 + 25 = 8525 locations/image. | |
| The NMS variant folds ONNX's NonMaxSuppression (opset >= 10) into | |
| the graph using the configured iou / score / max_detections | |
| parameters, producing a flat list of surviving detections across | |
| all batches and classes. Useful for single-shot TensorRT / mobile | |
| inference. Without NMS the consumer runs their own — hard vs soft, | |
| per-class vs global, threshold tuning — without re-exporting. | |
| """ | |
| def __init__(self, detection_head: nn.Module, resolution: int, | |
| include_nms: bool = False, | |
| nms_iou_threshold: float = 0.5, | |
| nms_score_threshold: float = 0.05, | |
| nms_max_detections: int = 100): | |
| super().__init__() | |
| self.detection_head = detection_head | |
| self.resolution = resolution | |
| self.num_classes = detection_head.num_classes | |
| self.include_nms = include_nms | |
| self.nms_iou_threshold = nms_iou_threshold | |
| self.nms_score_threshold = nms_score_threshold | |
| self.nms_max_detections = nms_max_detections | |
| # Compute per-level spatial sizes from the SimpleFeaturePyramid's actual | |
| # output shapes, not from resolution // stride. The pyramid starts at | |
| # stride-16 backbone features (H = resolution // 16) and produces: | |
| # P3 = 2*H via ConvTranspose2d(stride=2) | |
| # P4 = H via 1x1 + 3x3 convs (no stride) | |
| # P5 = (H+1)//2 via Conv2d(3x3, stride=2, padding=1) | |
| # P6 = (P5+1)//2 via Conv2d on P5 | |
| # P7 = (P6+1)//2 via Conv2d on P6 | |
| # When resolution is a multiple of 128, these match resolution // stride | |
| # exactly; at other resolutions the stride-2 convs round up via the | |
| # padding=1 kernel=3 formula, so P6/P7 are slightly larger than | |
| # nominal stride division suggests. Feature-pyramid-level locations | |
| # still use the nominal FPN_STRIDES for FCOS box decoding because | |
| # that's what eager `model.detect` does. | |
| H = resolution // 16 | |
| p3 = 2 * H | |
| p4 = H | |
| p5 = (H + 1) // 2 | |
| p6 = (p5 + 1) // 2 | |
| p7 = (p6 + 1) // 2 | |
| feat_sizes = [(p3, p3), (p4, p4), (p5, p5), (p6, p6), (p7, p7)] | |
| locs_per_level = [] | |
| for (h, w), s in zip(feat_sizes, FPN_STRIDES): | |
| ys = (torch.arange(h, dtype=torch.float32) + 0.5) * s | |
| xs = (torch.arange(w, dtype=torch.float32) + 0.5) * s | |
| gy, gx = torch.meshgrid(ys, xs, indexing="ij") | |
| locs_per_level.append(torch.stack([gx.flatten(), gy.flatten()], -1)) | |
| all_locs = torch.cat(locs_per_level, 0) | |
| self.register_buffer("all_locs", all_locs) | |
| def forward(self, spatial_features: Tensor): | |
| cls_logits, box_regs, centernesses = self.detection_head(spatial_features) | |
| B = spatial_features.shape[0] | |
| flat_cls = torch.cat( | |
| [c.permute(0, 2, 3, 1).reshape(B, -1, self.num_classes) for c in cls_logits], dim=1) | |
| flat_reg = torch.cat( | |
| [r.permute(0, 2, 3, 1).reshape(B, -1, 4) for r in box_regs], dim=1) | |
| flat_ctr = torch.cat( | |
| [c.permute(0, 2, 3, 1).reshape(B, -1, 1) for c in centernesses], dim=1) | |
| scores = torch.sigmoid(flat_cls) * torch.sigmoid(flat_ctr) | |
| locs = self.all_locs.unsqueeze(0).expand(B, -1, -1) | |
| x1 = (locs[..., 0:1] - flat_reg[..., 0:1]).clamp(0, self.resolution) | |
| y1 = (locs[..., 1:2] - flat_reg[..., 1:2]).clamp(0, self.resolution) | |
| x2 = (locs[..., 0:1] + flat_reg[..., 2:3]).clamp(0, self.resolution) | |
| y2 = (locs[..., 1:2] + flat_reg[..., 3:4]).clamp(0, self.resolution) | |
| boxes = torch.cat([x1, y1, x2, y2], dim=-1) | |
| if not self.include_nms: | |
| return boxes, scores | |
| # ONNX NMS expects boxes in [y1, x1, y2, x2] (center_point_box=0) and | |
| # scores with the class dim in the middle: [B, C, N]. | |
| boxes_yxyx = torch.cat([y1, x1, y2, x2], dim=-1) | |
| scores_bcn = scores.permute(0, 2, 1).contiguous() | |
| max_out = torch.tensor(self.nms_max_detections, dtype=torch.long, device=boxes.device) | |
| iou_thr = torch.tensor(self.nms_iou_threshold, dtype=torch.float32, device=boxes.device) | |
| score_thr = torch.tensor(self.nms_score_threshold, dtype=torch.float32, device=boxes.device) | |
| selected = _ONNXBatchedNMS.apply( | |
| boxes_yxyx, scores_bcn, max_out, iou_thr, score_thr, | |
| ) | |
| batch_idx = selected[:, 0].long() | |
| class_idx = selected[:, 1].long() | |
| box_idx = selected[:, 2].long() | |
| sel_boxes = boxes[batch_idx, box_idx] # [M, 4] xyxy | |
| sel_scores = scores[batch_idx, box_idx, class_idx] # [M] | |
| return sel_boxes, sel_scores, class_idx, batch_idx | |
| class SegmentationHead(nn.Module): | |
| def __init__(self, in_dim: int = 768, num_classes: int = 150): | |
| super().__init__() | |
| self.batchnorm_layer = nn.BatchNorm2d(in_dim) | |
| self.conv = nn.Conv2d(in_dim, num_classes, kernel_size=1) | |
| def forward(self, x: Tensor) -> Tensor: | |
| return self.conv(self.batchnorm_layer(x)) | |
| class DepthHead(nn.Module): | |
| def __init__(self, in_dim: int = 768, n_bins: int = 256, | |
| min_depth: float = 0.001, max_depth: float = 10.0): | |
| super().__init__() | |
| self.batchnorm_layer = nn.BatchNorm2d(in_dim) | |
| self.conv_depth = nn.Conv2d(in_dim, n_bins, kernel_size=1) | |
| self.min_depth = min_depth | |
| self.max_depth = max_depth | |
| self.n_bins = n_bins | |
| def forward(self, x: Tensor) -> Tensor: | |
| logits = self.conv_depth(self.batchnorm_layer(x)) | |
| logit = torch.relu(logits) + 0.1 | |
| logit = logit / logit.sum(dim=1, keepdim=True) | |
| bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device) | |
| return torch.einsum("bkhw,k->bhw", logit, bins).unsqueeze(1) | |
| # =========================================================================== | |
| # Detection (FCOS with ViTDet-style simple feature pyramid) | |
| # =========================================================================== | |
| FPN_STRIDES = [8, 16, 32, 64, 128] | |
| COCO_CLASSES = [ | |
| "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", | |
| "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", | |
| "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", | |
| "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", | |
| "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", | |
| "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", | |
| "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", | |
| "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", | |
| "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", | |
| "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", | |
| "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", | |
| "toothbrush", | |
| ] | |
| def cofiber_decompose(f: Tensor, n_scales: int) -> List[Tensor]: | |
| """Iterated multi-scale decomposition. Each step subtracts the | |
| downsampled-then-upsampled component of the current residual and | |
| recurses on the remainder. Zero learned parameters. The final entry is | |
| the lowest-frequency remainder.""" | |
| cofibers: List[Tensor] = [] | |
| residual = f | |
| for _ in range(n_scales - 1): | |
| omega = F.avg_pool2d(residual, 2) | |
| sigma_omega = F.interpolate(omega, size=residual.shape[2:], | |
| mode="bilinear", align_corners=False) | |
| cofibers.append(residual - sigma_omega) | |
| residual = omega | |
| cofibers.append(residual) | |
| return cofibers | |
| def make_sin_pos_emb(H: int, W: int, dim: int, device) -> Tensor: | |
| """2D sinusoidal positional encoding over an H x W grid. Concatenated | |
| to the backbone patch features before the head stem.""" | |
| assert dim % 4 == 0, "pos emb dim must be divisible by 4" | |
| d = dim // 4 | |
| ys = torch.arange(H, device=device, dtype=torch.float32) | |
| xs = torch.arange(W, device=device, dtype=torch.float32) | |
| omega = torch.exp(torch.arange(d, device=device, dtype=torch.float32) | |
| * -(math.log(10000.0) / d)) | |
| pe_y = torch.zeros(H, d * 2, device=device) | |
| pe_y[:, 0::2] = torch.sin(ys[:, None] * omega[None, :]) | |
| pe_y[:, 1::2] = torch.cos(ys[:, None] * omega[None, :]) | |
| pe_x = torch.zeros(W, d * 2, device=device) | |
| pe_x[:, 0::2] = torch.sin(xs[:, None] * omega[None, :]) | |
| pe_x[:, 1::2] = torch.cos(xs[:, None] * omega[None, :]) | |
| pos = torch.zeros(dim, H, W, device=device) | |
| pos[:d * 2] = pe_y.permute(1, 0)[:, :, None].expand(-1, H, W) | |
| pos[d * 2:] = pe_x.permute(1, 0)[None, :, :].expand(H, -1, W).permute(1, 0, 2) | |
| return pos.unsqueeze(0) | |
| class ConvGNBlock(nn.Module): | |
| def __init__(self, channels: int): | |
| super().__init__() | |
| self.conv = nn.Conv2d(channels, channels, 3, padding=1) | |
| self.norm = nn.GroupNorm(min(32, channels), channels) | |
| self.act = nn.GELU() | |
| def forward(self, x: Tensor) -> Tensor: | |
| return self.act(self.norm(self.conv(x))) | |
| class DWResBlock(nn.Module): | |
| def __init__(self, channels: int): | |
| super().__init__() | |
| self.pw = nn.Conv2d(channels, channels, 1) | |
| self.act = nn.GELU() | |
| self.dw = nn.Conv2d(channels, channels, 3, padding=1, groups=channels) | |
| self.norm = nn.GroupNorm(min(32, channels), channels) | |
| def forward(self, x: Tensor) -> Tensor: | |
| return x + self.norm(self.dw(self.act(self.pw(x)))) | |
| def make_tower(hidden: int, n_std: int, n_dw: int) -> nn.Sequential: | |
| layers: List[nn.Module] = [ConvGNBlock(hidden) for _ in range(n_std)] | |
| layers += [DWResBlock(hidden) for _ in range(n_dw)] | |
| return nn.Sequential(*layers) | |
| class SplitTowerHead(nn.Module): | |
| """Detection head operating on a cofiber decomposition of the frozen | |
| backbone features. Five prediction levels (strides 8, 16, 32, 64, 128): | |
| a stride-8 level synthesized by a transposed convolution from the | |
| stride-16 band and four cofiber bands at strides 16, 32, 64, 128. | |
| Separate classification and regression towers of depth (n_std_layers + | |
| n_dw_layers) with weights shared across levels. Classification via | |
| cosine similarity against frozen CLIP text-encoder embeddings of the | |
| COCO class names; regression via exponentiated LTRB distances with a | |
| learned per-level scale; centerness via a single 1x1 convolution. | |
| Inference-only within Argus: no DFL, no IoU-aware branch, no | |
| per-scale bias. The text_embed buffer is populated by from_pretrained's | |
| state_dict load.""" | |
| def __init__(self, | |
| feat_dim: int = 768, | |
| hidden: int = 160, | |
| n_std_layers: int = 5, | |
| n_dw_layers: int = 4, | |
| n_scales: int = 4, | |
| pos_emb_dim: int = 64, | |
| num_classes: int = 80, | |
| text_embed_dim: int = 768): | |
| super().__init__() | |
| self.n_scales = n_scales | |
| self.pos_emb_dim = pos_emb_dim | |
| self.num_classes = num_classes | |
| self.text_embed_dim = text_embed_dim | |
| n_total = n_scales + 1 | |
| input_dim = feat_dim + pos_emb_dim | |
| self.scale_norms = nn.ModuleList([nn.GroupNorm(1, input_dim) for _ in range(n_scales)]) | |
| self.stem = nn.Conv2d(input_dim, hidden, 1) | |
| self.stem_act = nn.GELU() | |
| self.p3_upsample = nn.ConvTranspose2d(hidden, hidden, 2, stride=2) | |
| self.p3_norm = nn.GroupNorm(min(32, hidden), hidden) | |
| self.lateral_convs = nn.ModuleList([nn.Conv2d(hidden, hidden, 1) for _ in range(n_scales - 1)]) | |
| self.lateral_norms = nn.ModuleList( | |
| [nn.GroupNorm(min(32, hidden), hidden) for _ in range(n_scales - 1)]) | |
| self.cls_tower = make_tower(hidden, n_std_layers, n_dw_layers) | |
| self.reg_tower = make_tower(hidden, n_std_layers, n_dw_layers) | |
| # CLIP text-aligned classifier. The text_embed buffer is filled from | |
| # the state dict at from_pretrained; the zero placeholder here only | |
| # exists so the module can be constructed before weights arrive. | |
| self.register_buffer("text_embed", | |
| torch.zeros(num_classes, text_embed_dim)) | |
| self.cls_project = nn.Linear(hidden, text_embed_dim, bias=False) | |
| self.logit_scale = nn.Parameter(torch.tensor(math.log(1.0 / 0.07))) | |
| self.cls_bias = nn.Parameter(torch.full((num_classes,), -math.log(99))) | |
| self.reg_pred = nn.Conv2d(hidden, 4, 1) | |
| self.ctr_pred = nn.Conv2d(hidden, 1, 1) | |
| self.scale_params = nn.Parameter(torch.ones(n_total)) | |
| def forward(self, spatial: Tensor) -> Tuple[List[Tensor], List[Tensor], List[Tensor]]: | |
| B, C, H_, W_ = spatial.shape | |
| pos = make_sin_pos_emb(H_, W_, self.pos_emb_dim, spatial.device).expand(B, -1, -1, -1) | |
| spatial = torch.cat([spatial, pos], dim=1) | |
| cofibers = cofiber_decompose(spatial, self.n_scales) | |
| scale_features: List[Tensor] = [] | |
| for i, cof in enumerate(cofibers): | |
| x = self.stem_act(self.stem(self.scale_norms[i](cof))) | |
| scale_features.append(x) | |
| # Top-down lateral fusion from coarser to finer scales. | |
| for i in range(len(scale_features) - 2, -1, -1): | |
| coarse_up = F.interpolate(scale_features[i + 1], | |
| size=scale_features[i].shape[2:], | |
| mode="bilinear", align_corners=False) | |
| scale_features[i] = self.lateral_norms[i]( | |
| scale_features[i] + self.lateral_convs[i](coarse_up)) | |
| p3 = self.p3_norm(self.p3_upsample(scale_features[0])) | |
| all_features = [p3] + scale_features | |
| cls_l, reg_l, ctr_l = [], [], [] | |
| for i, x in enumerate(all_features): | |
| cls_feat = self.cls_tower(x) | |
| reg_feat = self.reg_tower(x) | |
| B_, _, Hi, Wi = cls_feat.shape | |
| f = cls_feat.permute(0, 2, 3, 1).reshape(-1, cls_feat.shape[1]) | |
| f_proj = self.cls_project(f) | |
| f_norm = F.normalize(f_proj, p=2, dim=-1) | |
| logits = f_norm @ self.text_embed.t() | |
| cls = (logits * self.logit_scale.exp() + self.cls_bias).reshape( | |
| B_, Hi, Wi, self.num_classes).permute(0, 3, 1, 2) | |
| reg_raw = (self.reg_pred(reg_feat) * self.scale_params[i]).clamp(-10, 10) | |
| reg = reg_raw.exp() | |
| ctr = self.ctr_pred(reg_feat) | |
| cls_l.append(cls) | |
| reg_l.append(reg) | |
| ctr_l.append(ctr) | |
| return cls_l, reg_l, ctr_l | |
| def _make_locations(feature_sizes: List[Tuple[int, int]], strides: List[int], device) -> List[Tensor]: | |
| """Per-level center coordinates of feature-map locations in image space.""" | |
| all_locs = [] | |
| for (h, w), s in zip(feature_sizes, strides): | |
| ys = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * s | |
| xs = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * s | |
| grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij") | |
| locs = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1) | |
| all_locs.append(locs) | |
| return all_locs | |
| def _decode_detections( | |
| cls_logits_per_level: List[Tensor], | |
| box_regs_per_level: List[Tensor], | |
| centernesses_per_level: List[Tensor], | |
| locations_per_level: List[Tensor], | |
| image_sizes: List[Tuple[int, int]], | |
| score_thresh: float = 0.05, | |
| nms_thresh: float = 0.5, | |
| max_per_level: int = 1000, | |
| max_per_image: int = 100, | |
| ) -> List[Dict[str, Tensor]]: | |
| """Convert per-level logits/regs/centerness into per-image detections (xyxy boxes).""" | |
| B = cls_logits_per_level[0].shape[0] | |
| num_classes = cls_logits_per_level[0].shape[1] | |
| device = cls_logits_per_level[0].device | |
| per_image_results = [] | |
| for image_idx in range(B): | |
| all_boxes, all_scores, all_labels = [], [], [] | |
| for cls_l, reg_l, ctr_l, locs_l in zip( | |
| cls_logits_per_level, box_regs_per_level, centernesses_per_level, locations_per_level | |
| ): | |
| cls = cls_l[image_idx].permute(1, 2, 0).reshape(-1, num_classes) | |
| reg = reg_l[image_idx].permute(1, 2, 0).reshape(-1, 4) | |
| ctr = ctr_l[image_idx].permute(1, 2, 0).reshape(-1) | |
| cls_prob = torch.sigmoid(cls) | |
| ctr_prob = torch.sigmoid(ctr) | |
| scores = cls_prob * ctr_prob[:, None] | |
| mask = scores > score_thresh | |
| if not mask.any(): | |
| continue | |
| cand_loc, cand_cls = mask.nonzero(as_tuple=True) | |
| cand_scores = scores[cand_loc, cand_cls] | |
| if cand_scores.numel() > max_per_level: | |
| top = cand_scores.topk(max_per_level) | |
| cand_scores = top.values | |
| idx = top.indices | |
| cand_loc = cand_loc[idx] | |
| cand_cls = cand_cls[idx] | |
| cand_locs_xy = locs_l[cand_loc] | |
| cand_reg = reg[cand_loc] | |
| boxes = torch.stack([ | |
| cand_locs_xy[:, 0] - cand_reg[:, 0], | |
| cand_locs_xy[:, 1] - cand_reg[:, 1], | |
| cand_locs_xy[:, 0] + cand_reg[:, 2], | |
| cand_locs_xy[:, 1] + cand_reg[:, 3], | |
| ], dim=-1) | |
| all_boxes.append(boxes) | |
| all_scores.append(cand_scores) | |
| all_labels.append(cand_cls) | |
| if all_boxes: | |
| boxes = torch.cat(all_boxes, dim=0) | |
| scores = torch.cat(all_scores, dim=0) | |
| labels = torch.cat(all_labels, dim=0) | |
| H, W = image_sizes[image_idx] | |
| boxes[:, 0::2] = boxes[:, 0::2].clamp(0, W) | |
| boxes[:, 1::2] = boxes[:, 1::2].clamp(0, H) | |
| keep_all = [] | |
| for c in labels.unique(): | |
| cm = labels == c | |
| keep = nms(boxes[cm], scores[cm], nms_thresh) | |
| keep_idx = cm.nonzero(as_tuple=True)[0][keep] | |
| keep_all.append(keep_idx) | |
| keep_all = torch.cat(keep_all, dim=0) | |
| boxes = boxes[keep_all] | |
| scores = scores[keep_all] | |
| labels = labels[keep_all] | |
| if scores.numel() > max_per_image: | |
| top = scores.topk(max_per_image) | |
| boxes = boxes[top.indices] | |
| scores = top.values | |
| labels = labels[top.indices] | |
| else: | |
| boxes = torch.zeros((0, 4), device=device) | |
| scores = torch.zeros((0,), device=device) | |
| labels = torch.zeros((0,), dtype=torch.long, device=device) | |
| per_image_results.append({"boxes": boxes, "scores": scores, "labels": labels}) | |
| return per_image_results | |
| def _letterbox_to_square(image: Image.Image, resolution: int) -> Tuple[Image.Image, float, Tuple[int, int]]: | |
| """Resize preserving aspect ratio and pad bottom/right with black. Matches the training transform.""" | |
| W0, H0 = image.size | |
| scale = resolution / max(H0, W0) | |
| new_w = int(round(W0 * scale)) | |
| new_h = int(round(H0 * scale)) | |
| resized = image.resize((new_w, new_h), Image.BILINEAR) | |
| canvas = Image.new("RGB", (resolution, resolution), (0, 0, 0)) | |
| canvas.paste(resized, (0, 0)) | |
| return canvas, scale, (W0, H0) | |
| # =========================================================================== | |
| # DPT depth decoder (multi-scale, hooks into ViT blocks [2, 5, 8, 11]) | |
| # =========================================================================== | |
| HOOK_BLOCK_INDICES = [2, 5, 8, 11] | |
| N_PREFIX_TOKENS = 5 # 1 CLS + 4 register/storage tokens | |
| class _ResidualConvUnit(nn.Module): | |
| """Two 3x3 conv + BatchNorm blocks with a residual connection. Padding | |
| mode is configurable: the Argus-B DPT depth head trains with reflect | |
| padding to avoid edge artifacts; Argus-Lite ships weights that were | |
| trained with zero padding (the PyTorch default), and switching pad | |
| modes at inference would create a small distribution shift in the | |
| edge regions. Variants pass `padding_mode` to keep their inference | |
| aligned with their training.""" | |
| def __init__(self, dim: int, padding_mode: str = "reflect"): | |
| super().__init__() | |
| self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode=padding_mode, bias=False) | |
| self.bn1 = nn.BatchNorm2d(dim) | |
| self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode=padding_mode, bias=False) | |
| self.bn2 = nn.BatchNorm2d(dim) | |
| self.act = nn.GELU() | |
| def forward(self, x: Tensor) -> Tensor: | |
| return x + self.bn2(self.conv2(self.act(self.bn1(self.conv1(x))))) | |
| class _FeatureFusionBlock(nn.Module): | |
| def __init__(self, dim: int, has_skip: bool = True, padding_mode: str = "reflect"): | |
| super().__init__() | |
| self.rcu1 = _ResidualConvUnit(dim, padding_mode=padding_mode) | |
| self.rcu2 = _ResidualConvUnit(dim, padding_mode=padding_mode) | |
| self.skip_proj = nn.Conv2d(dim, dim, 1) if has_skip else None | |
| def forward(self, x: Tensor, skip: Optional[Tensor] = None) -> Tensor: | |
| if skip is not None and self.skip_proj is not None: | |
| x = x + self.skip_proj(skip) | |
| x = self.rcu1(x) | |
| x = self.rcu2(x) | |
| return F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False) | |
| class _DPTReassemble(nn.Module): | |
| def __init__(self, in_dim: int = 768, out_dim: int = 256): | |
| super().__init__() | |
| self.projects = nn.ModuleList([ | |
| nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, out_dim)) | |
| for _ in range(4) | |
| ]) | |
| self.refine = nn.ModuleList([ | |
| nn.Sequential( | |
| nn.Conv2d(out_dim, out_dim, 3, padding=1, padding_mode="reflect", bias=False), | |
| nn.BatchNorm2d(out_dim), | |
| nn.GELU(), | |
| ) | |
| for _ in range(4) | |
| ]) | |
| def forward(self, intermediates: List[Tensor], H: int, W: int) -> List[Tensor]: | |
| out = [] | |
| for feat, proj, refine in zip(intermediates, self.projects, self.refine): | |
| patches = feat[:, N_PREFIX_TOKENS:, :] | |
| patches = proj(patches) | |
| B, N, D = patches.shape | |
| spatial = patches.permute(0, 2, 1).reshape(B, D, H, W) | |
| out.append(refine(spatial)) | |
| level_4 = F.interpolate(out[0], scale_factor=4, mode="bilinear", align_corners=False) | |
| level_8 = F.interpolate(out[1], scale_factor=2, mode="bilinear", align_corners=False) | |
| level_16 = out[2] | |
| level_32 = F.interpolate(out[3], scale_factor=0.5, mode="bilinear", align_corners=False) | |
| return [level_4, level_8, level_16, level_32] | |
| class DPTDepthDecoder(nn.Module): | |
| def __init__(self, in_dim: int = 768, decoder_dim: int = 256, | |
| n_bins: int = 256, min_depth: float = 0.001, max_depth: float = 10.0): | |
| super().__init__() | |
| self.n_bins = n_bins | |
| self.min_depth = min_depth | |
| self.max_depth = max_depth | |
| self.reassemble = _DPTReassemble(in_dim=in_dim, out_dim=decoder_dim) | |
| self.fusion_blocks = nn.ModuleList([ | |
| _FeatureFusionBlock(decoder_dim, has_skip=True), | |
| _FeatureFusionBlock(decoder_dim, has_skip=True), | |
| _FeatureFusionBlock(decoder_dim, has_skip=True), | |
| _FeatureFusionBlock(decoder_dim, has_skip=False), | |
| ]) | |
| self.head = nn.Sequential( | |
| nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, padding_mode="reflect", bias=False), | |
| nn.BatchNorm2d(decoder_dim), | |
| nn.GELU(), | |
| nn.Conv2d(decoder_dim, n_bins, 1), | |
| ) | |
| def forward(self, intermediates: List[Tensor], H: int, W: int, | |
| return_distribution: bool = False): | |
| levels = self.reassemble(intermediates, H, W) | |
| x = self.fusion_blocks[3](levels[3]) | |
| x = self.fusion_blocks[2](x, skip=levels[2]) | |
| x = self.fusion_blocks[1](x, skip=levels[1]) | |
| x = self.fusion_blocks[0](x, skip=levels[0]) | |
| logits = self.head(x) | |
| distribution = torch.relu(logits) + 0.1 | |
| distribution = distribution / distribution.sum(dim=1, keepdim=True) | |
| bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=x.device) | |
| depth = torch.einsum("bkhw,k->bhw", distribution, bins).unsqueeze(1) | |
| if return_distribution: | |
| return depth, distribution, bins | |
| return depth | |
| # =========================================================================== | |
| # Argus model (transformers-compatible) | |
| # =========================================================================== | |
| class ArgusConfig(PretrainedConfig): | |
| model_type = "argus" | |
| def __init__( | |
| self, | |
| embed_dim: int = 768, | |
| patch_size: int = 16, | |
| num_seg_classes: int = 150, | |
| depth_n_bins: int = 256, | |
| depth_min_depth: float = 0.001, | |
| depth_max_depth: float = 10.0, | |
| num_imagenet_classes: int = 1000, | |
| class_ids: Optional[list] = None, | |
| class_names: Optional[list] = None, | |
| detection_num_classes: int = 80, | |
| detection_hidden: int = 160, | |
| detection_n_std_layers: int = 5, | |
| detection_n_dw_layers: int = 4, | |
| detection_n_scales: int = 4, | |
| detection_pos_emb_dim: int = 64, | |
| detection_text_embed_dim: int = 768, | |
| detection_class_names: Optional[list] = None, | |
| **kwargs, | |
| ): | |
| super().__init__(**kwargs) | |
| self.embed_dim = embed_dim | |
| self.patch_size = patch_size | |
| self.num_seg_classes = num_seg_classes | |
| self.depth_n_bins = depth_n_bins | |
| self.depth_min_depth = depth_min_depth | |
| self.depth_max_depth = depth_max_depth | |
| self.num_imagenet_classes = num_imagenet_classes | |
| self.class_ids = class_ids or [] | |
| self.class_names = class_names or [] | |
| self.detection_num_classes = detection_num_classes | |
| self.detection_hidden = detection_hidden | |
| self.detection_n_std_layers = detection_n_std_layers | |
| self.detection_n_dw_layers = detection_n_dw_layers | |
| self.detection_n_scales = detection_n_scales | |
| self.detection_pos_emb_dim = detection_pos_emb_dim | |
| self.detection_text_embed_dim = detection_text_embed_dim | |
| self.detection_class_names = detection_class_names or list(COCO_CLASSES) | |
| class Argus(PreTrainedModel): | |
| config_class = ArgusConfig | |
| base_model_prefix = "argus" | |
| supports_gradient_checkpointing = False | |
| _tied_weights_keys: list = [] | |
| all_tied_weights_keys: dict = {} | |
| def __init__(self, config: ArgusConfig): | |
| super().__init__(config) | |
| self.backbone = build_eupe_vitb16() | |
| self.seg_head = SegmentationHead(config.embed_dim, config.num_seg_classes) | |
| self.depth_head = DPTDepthDecoder( | |
| in_dim=config.embed_dim, | |
| decoder_dim=256, | |
| n_bins=config.depth_n_bins, | |
| min_depth=config.depth_min_depth, | |
| max_depth=config.depth_max_depth, | |
| ) | |
| self.register_buffer( | |
| "class_logit_weight", | |
| torch.zeros(config.num_imagenet_classes, config.embed_dim), | |
| persistent=True, | |
| ) | |
| self.register_buffer( | |
| "class_logit_bias", | |
| torch.zeros(config.num_imagenet_classes), | |
| persistent=True, | |
| ) | |
| self.detection_head = SplitTowerHead( | |
| feat_dim=config.embed_dim, | |
| hidden=config.detection_hidden, | |
| n_std_layers=config.detection_n_std_layers, | |
| n_dw_layers=config.detection_n_dw_layers, | |
| n_scales=config.detection_n_scales, | |
| pos_emb_dim=config.detection_pos_emb_dim, | |
| num_classes=config.detection_num_classes, | |
| text_embed_dim=config.detection_text_embed_dim, | |
| ) | |
| for p in self.backbone.parameters(): | |
| p.requires_grad = False | |
| self.backbone.eval() | |
| self.seg_head.eval() | |
| self.depth_head.eval() | |
| self.detection_head.eval() | |
| def _init_weights(self, module): | |
| # HF reallocates missing buffers and parameters with torch.empty() | |
| # (uninitialized memory) on from_pretrained. Populate sensible defaults | |
| # for the standard layer types used by the detection head, and zero any | |
| # Argus-level buffer that came back NaN. | |
| if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): | |
| nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu") | |
| if module.bias is not None: | |
| nn.init.zeros_(module.bias) | |
| elif isinstance(module, nn.GroupNorm): | |
| nn.init.ones_(module.weight) | |
| nn.init.zeros_(module.bias) | |
| if module is self: | |
| for name in ("class_logit_weight", "class_logit_bias"): | |
| if hasattr(self, name): | |
| buf = getattr(self, name) | |
| if torch.isnan(buf).any() or torch.isinf(buf).any(): | |
| buf.data.zero_() | |
| def _load_imagenet_classes(self): | |
| if getattr(self, "_imagenet_classes_loaded", False): | |
| return | |
| self._imagenet_classes_loaded = True | |
| import json | |
| import os as _os | |
| candidates = [] | |
| here = _os.path.dirname(_os.path.abspath(__file__)) | |
| candidates.append(_os.path.join(here, "imagenet_classes.json")) | |
| name_or_path = getattr(self.config, "_name_or_path", None) | |
| if name_or_path and _os.path.isdir(name_or_path): | |
| candidates.append(_os.path.join(name_or_path, "imagenet_classes.json")) | |
| for path in candidates: | |
| if _os.path.isfile(path): | |
| with open(path) as f: | |
| data = json.load(f) | |
| self.config.class_ids = data.get("class_ids", []) | |
| self.config.class_names = data.get("class_names", []) | |
| return | |
| if name_or_path and not _os.path.isdir(name_or_path): | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| path = hf_hub_download(name_or_path, "imagenet_classes.json") | |
| with open(path) as f: | |
| data = json.load(f) | |
| self.config.class_ids = data.get("class_ids", []) | |
| self.config.class_names = data.get("class_names", []) | |
| except Exception: | |
| pass | |
| def class_ids(self): | |
| if not self.config.class_ids: | |
| self._load_imagenet_classes() | |
| return self.config.class_ids | |
| def class_names(self): | |
| if not self.config.class_names: | |
| self._load_imagenet_classes() | |
| return self.config.class_names | |
| def quantize_int8(self): | |
| """Apply INT8 weight-only quantization via torchao. Reduces VRAM by ~11% | |
| with negligible accuracy loss (<0.05 m depth drift, 100% classification | |
| agreement). Requires torchao: pip install torchao.""" | |
| try: | |
| from torchao.quantization import quantize_, Int8WeightOnlyConfig | |
| except ImportError as e: | |
| raise ImportError("torchao is required for INT8 quantization: pip install torchao") from e | |
| quantize_(self, Int8WeightOnlyConfig()) | |
| return self | |
| def _extract(self, image_tensor: Tensor) -> Tuple[Tensor, Tensor]: | |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): | |
| out = self.backbone.forward_features(image_tensor) | |
| cls = out["x_norm_clstoken"].float() | |
| patches = out["x_norm_patchtokens"].float() | |
| B, N, D = patches.shape | |
| h = w = int(N ** 0.5) | |
| spatial = patches.permute(0, 2, 1).reshape(B, D, h, w) | |
| return cls, spatial | |
| def classify(self, image_or_images, top_k: int = 5): | |
| single, images = _normalize_image_input(image_or_images) | |
| transform = make_eupe_transform(224) | |
| batch = torch.stack([transform(img) for img in images]).to(self.device) | |
| cls, _ = self._extract(batch) | |
| cls = F.normalize(cls, dim=-1) | |
| w = self.class_logit_weight.to(cls.dtype) | |
| b = self.class_logit_bias.to(cls.dtype) | |
| logits = F.linear(cls, w, b) | |
| scores_full = F.softmax(logits, dim=-1) | |
| topk = scores_full.topk(top_k, dim=-1) | |
| top2 = scores_full.topk(2, dim=-1) | |
| margins = (top2.values[:, 0] - top2.values[:, 1]).tolist() | |
| results = [] | |
| for b in range(len(images)): | |
| entries = [] | |
| for score, idx in zip(topk.values[b].tolist(), topk.indices[b].tolist()): | |
| entries.append({ | |
| "class_id": self.class_ids[idx], | |
| "class_name": self.class_names[idx], | |
| "score": float(score), | |
| }) | |
| entries[0]["margin"] = float(margins[b]) | |
| results.append(entries) | |
| return results[0] if single else results | |
| def segment(self, image_or_images, resolution: int = 512, return_confidence: bool = False): | |
| single, images = _normalize_image_input(image_or_images) | |
| transform = make_eupe_transform(resolution) | |
| batch = torch.stack([transform(img) for img in images]).to(self.device) | |
| _, spatial = self._extract(batch) | |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): | |
| logits = self.seg_head(spatial) | |
| logits = F.interpolate(logits, size=(resolution, resolution), mode="bilinear", align_corners=False) | |
| seg_maps = logits.argmax(dim=1) # [B, H, W] | |
| if return_confidence: | |
| probs = F.softmax(logits.float(), dim=1) | |
| conf_maps = probs.max(dim=1).values # [B, H, W] in [0, 1] | |
| if single: | |
| return seg_maps[0], conf_maps[0] | |
| return [(seg_maps[i], conf_maps[i]) for i in range(len(images))] | |
| if single: | |
| return seg_maps[0] | |
| return [seg_maps[i] for i in range(len(images))] | |
| def depth(self, image_or_images, resolution: int = 416, return_confidence: bool = False, | |
| crop_border: bool = False): | |
| """Run the DPT depth decoder. Returns metric depth in meters at the | |
| input resolution. | |
| ``crop_border=True`` strips a small border (``max(4, H/13)`` pixels per | |
| side) from the raw decoder output before bilinear-upsampling to the | |
| input resolution. Useful when this model is loaded with a backbone | |
| whose DPT decoder was trained with zero padding (the unshipped | |
| dev-fork behaviour), which leaves a systematic edge artifact. The | |
| canonical checkpoint uses reflect padding inside every DPT conv and | |
| does not need this crop, so the option defaults to ``False``.""" | |
| single, images = _normalize_image_input(image_or_images) | |
| transform = make_eupe_transform(resolution) | |
| batch = torch.stack([transform(img) for img in images]).to(self.device) | |
| # Hook into intermediate ViT blocks for multi-scale features | |
| intermediates = {} | |
| hooks = [] | |
| for idx in HOOK_BLOCK_INDICES: | |
| def _make_hook(block_idx): | |
| def _hook(module, inp, out): | |
| intermediates[block_idx] = out[0] if isinstance(out, list) else out | |
| return _hook | |
| hooks.append(self.backbone.blocks[idx].register_forward_hook(_make_hook(idx))) | |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): | |
| self.backbone.forward_features(batch) | |
| for h in hooks: | |
| h.remove() | |
| inter_list = [intermediates[idx].float() for idx in HOOK_BLOCK_INDICES] | |
| H = W = resolution // 16 | |
| if return_confidence: | |
| depth_b, distribution, bins = self.depth_head( | |
| inter_list, H, W, return_distribution=True) | |
| # Std of the 256-bin depth distribution: var = E[X^2] - E[X]^2. | |
| mean_sq = torch.einsum("bkhw,k->bhw", distribution, bins ** 2) | |
| variance = (mean_sq - depth_b.squeeze(1) ** 2).clamp(min=0) | |
| std_b = torch.sqrt(variance).unsqueeze(1) | |
| else: | |
| depth_b = self.depth_head(inter_list, H, W) | |
| std_b = None | |
| if crop_border: | |
| crop = max(4, depth_b.shape[2] // 13) | |
| depth_b = depth_b[:, :, crop:-crop, crop:-crop] | |
| if std_b is not None: | |
| std_b = std_b[:, :, crop:-crop, crop:-crop] | |
| depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False) | |
| if std_b is not None: | |
| std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False) | |
| depth_squeezed = depth_b[:, 0].float() | |
| if return_confidence: | |
| std_squeezed = std_b[:, 0].float() | |
| if single: | |
| return depth_squeezed[0], std_squeezed[0] | |
| return [(depth_squeezed[i], std_squeezed[i]) for i in range(len(images))] | |
| if single: | |
| return depth_squeezed[0] | |
| return [depth_squeezed[i] for i in range(len(images))] | |
| def correspond( | |
| self, | |
| src_image, | |
| tgt_image, | |
| resolution: int = 512, | |
| ): | |
| """Dense patch correspondence between two images. | |
| Single-pair form: pass two `PIL.Image` instances. Returns a dict with | |
| keys `matches` (numpy array of length grid*grid mapping each source | |
| patch to its argmax target patch), `scores` (cosine similarity at the | |
| match), and `grid` (the patch-grid side length). | |
| Batched form: pass two equally-sized lists/iterables of images. Returns | |
| a list of per-pair dicts in the same shape that a single call would | |
| produce. Both lists are forwarded through the backbone in two | |
| contiguous batches, so cross-pair throughput on GPU is much higher | |
| than calling `correspond` in a loop. | |
| """ | |
| single = isinstance(src_image, Image.Image) and isinstance(tgt_image, Image.Image) | |
| if single: | |
| srcs = [src_image] | |
| tgts = [tgt_image] | |
| else: | |
| srcs = list(src_image) | |
| tgts = list(tgt_image) | |
| if len(srcs) != len(tgts): | |
| raise ValueError( | |
| f"src_image and tgt_image must have the same length; " | |
| f"got {len(srcs)} and {len(tgts)}") | |
| if not srcs: | |
| raise ValueError("empty image list") | |
| for i, (a, b) in enumerate(zip(srcs, tgts)): | |
| if not isinstance(a, Image.Image) or not isinstance(b, Image.Image): | |
| raise TypeError(f"pair {i} must contain two PIL.Image instances") | |
| transform = make_eupe_transform(resolution) | |
| src_batch = torch.stack([transform(img) for img in srcs]).to(self.device) | |
| tgt_batch = torch.stack([transform(img) for img in tgts]).to(self.device) | |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): | |
| oa = self.backbone.forward_features(src_batch) | |
| ob = self.backbone.forward_features(tgt_batch) | |
| pa_batch = F.normalize(oa['x_norm_patchtokens'].float(), dim=-1) | |
| pb_batch = F.normalize(ob['x_norm_patchtokens'].float(), dim=-1) | |
| results = [] | |
| for pa, pb in zip(pa_batch, pb_batch): | |
| sim = pa @ pb.t() | |
| m = sim.argmax(dim=-1) | |
| s = sim.max(dim=-1).values | |
| grid = int(np.sqrt(pa.shape[0])) | |
| results.append({ | |
| "matches": m.cpu().numpy(), | |
| "scores": s.cpu().numpy(), | |
| "grid": grid, | |
| }) | |
| return results[0] if single else results | |
| def detect( | |
| self, | |
| image_or_images, | |
| resolution: int = 768, | |
| score_thresh: float = 0.05, | |
| nms_thresh: float = 0.5, | |
| max_per_image: int = 100, | |
| ): | |
| single, images = _normalize_image_input(image_or_images) | |
| # Letterbox each image to match the training transform (resize long side | |
| # to `resolution`, pad bottom/right with black). Box coordinates are | |
| # recovered after decoding by unscaling. | |
| canvases, scales, orig_sizes = [], [], [] | |
| for img in images: | |
| canvas, scale, orig = _letterbox_to_square(img, resolution) | |
| canvases.append(canvas) | |
| scales.append(scale) | |
| orig_sizes.append(orig) | |
| det_normalize = v2.Compose([ | |
| v2.ToImage(), | |
| v2.ToDtype(torch.float32, scale=True), | |
| v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), | |
| ]) | |
| batch = torch.stack([det_normalize(c) for c in canvases]).to(self.device) | |
| _, spatial = self._extract(batch) | |
| with torch.autocast(self.device.type, dtype=torch.bfloat16, enabled=self.device.type == "cuda"): | |
| cls_logits, box_regs, centernesses = self.detection_head(spatial) | |
| cls_logits = [c.float() for c in cls_logits] | |
| box_regs = [b.float() for b in box_regs] | |
| centernesses = [c.float() for c in centernesses] | |
| feature_sizes = [(cl.shape[2], cl.shape[3]) for cl in cls_logits] | |
| locations = _make_locations(feature_sizes, FPN_STRIDES, spatial.device) | |
| image_sizes = [(resolution, resolution)] * len(images) | |
| results = _decode_detections( | |
| cls_logits, box_regs, centernesses, locations, | |
| image_sizes=image_sizes, | |
| score_thresh=score_thresh, | |
| nms_thresh=nms_thresh, | |
| max_per_image=max_per_image, | |
| ) | |
| class_names = self.config.detection_class_names | |
| formatted = [] | |
| for i, r in enumerate(results): | |
| scale = scales[i] | |
| orig_w, orig_h = orig_sizes[i] | |
| boxes = r["boxes"].cpu().numpy() / scale | |
| boxes[:, 0::2] = boxes[:, 0::2].clip(0, orig_w) | |
| boxes[:, 1::2] = boxes[:, 1::2].clip(0, orig_h) | |
| detections = [] | |
| for box, score, label in zip( | |
| boxes, r["scores"].cpu().numpy(), r["labels"].cpu().numpy() | |
| ): | |
| detections.append({ | |
| "box": [float(v) for v in box.tolist()], | |
| "score": float(score), | |
| "label": int(label), | |
| "class_name": class_names[int(label)] if int(label) < len(class_names) else f"class_{int(label)}", | |
| }) | |
| formatted.append(detections) | |
| return formatted[0] if single else formatted | |
| def perceive(self, image_or_images, return_confidence: bool = False): | |
| single, images = _normalize_image_input(image_or_images) | |
| t0 = time.time() | |
| classif = self.classify(images, top_k=5) | |
| t1 = time.time() | |
| seg_out = self.segment(images, resolution=512, return_confidence=return_confidence) | |
| t2 = time.time() | |
| depth_out = self.depth(images, resolution=416, return_confidence=return_confidence) | |
| t3 = time.time() | |
| if return_confidence: | |
| seg_maps = [s for s, _ in seg_out] | |
| seg_confs = [c for _, c in seg_out] | |
| depth_maps = [d for d, _ in depth_out] | |
| depth_uncerts = [u for _, u in depth_out] | |
| else: | |
| seg_maps = seg_out | |
| depth_maps = depth_out | |
| seg_confs = depth_uncerts = None | |
| timings = { | |
| "classify": (t1 - t0) * 1000, | |
| "segment": (t2 - t1) * 1000, | |
| "depth": (t3 - t2) * 1000, | |
| "total": (t3 - t0) * 1000, | |
| } | |
| results = [] | |
| for i in range(len(images)): | |
| entry = { | |
| "classification": classif[i], | |
| "segmentation": seg_maps[i].cpu().numpy(), | |
| "depth": depth_maps[i].cpu().numpy(), | |
| "timings_ms": timings, | |
| } | |
| if return_confidence: | |
| entry["segmentation_confidence"] = seg_confs[i].cpu().numpy() | |
| entry["depth_uncertainty"] = depth_uncerts[i].cpu().numpy() | |
| results.append(entry) | |
| return results[0] if single else results | |
| def export_onnx( | |
| self, | |
| out_dir: str, | |
| backbone_resolution: int = 224, | |
| dynamic_batch: bool = True, | |
| verify: bool = True, | |
| tolerance: Union[float, Dict[str, float]] = 5e-2, | |
| opset_version: int = 17, | |
| include_nms: bool = False, | |
| nms_iou_threshold: float = 0.5, | |
| nms_score_threshold: float = 0.05, | |
| nms_max_detections: int = 100, | |
| ) -> dict: | |
| """Export backbone, classifier, seg head, depth head, and detection head to ONNX. | |
| Produces five graphs: | |
| - argus_backbone.onnx image[B,3,H,W] -> cls[B,D], spatial[B,D,H/16,W/16] | |
| - argus_classifier.onnx cls_token[B,D] -> probs[B,1000] | |
| - argus_seg_head.onnx spatial_features[B,D,h,w] -> seg_logits[B,150,H,W] | |
| - argus_depth_head.onnx intermediate_{0..3}[B,N+5,D] -> depth_map[B,1,~8h,~8w] | |
| - argus_detection_head.onnx spatial_features[B,D,h,w] -> boxes, scores (+ labels, batch_indices if include_nms) | |
| The seg graph folds bilinear upsample to input resolution into the | |
| graph, so consumers argmax directly without a separate interpolation | |
| step. Correspondence has no learned parameters — it runs as | |
| cosine-max on the backbone's spatial output and needs no graph. | |
| ``include_nms=True`` bakes an ONNX NonMaxSuppression (opset >= 10) | |
| op into the detection head. The detection graph then emits four | |
| post-NMS tensors (boxes [M,4], scores [M], class_labels [M], | |
| batch_indices [M]) instead of the raw (boxes, scores) pair. Useful | |
| for single-shot TensorRT / mobile inference. The default | |
| ``include_nms=False`` leaves NMS to the consumer so they can choose | |
| hard vs soft, per-class vs global, and tune thresholds without | |
| re-exporting. | |
| ``tolerance`` can be a float (applied uniformly to every | |
| ``*_max_diff`` check) or a dict keyed by verification output name | |
| (e.g. ``{"detection_boxes_max_diff": 3.2, "default": 5e-2}``). The | |
| ``"default"`` key covers outputs not otherwise listed. If a float | |
| is passed, detection box coordinates get a resolution-scaled | |
| tolerance (``max(tolerance, backbone_resolution * 5e-3)``) because | |
| exp() in the FCOS regression path amplifies FP kernel-dispatch | |
| differences to pixel-scale absolute diffs. | |
| """ | |
| import os | |
| os.makedirs(out_dir, exist_ok=True) | |
| if backbone_resolution % self.config.patch_size != 0: | |
| raise ValueError( | |
| f"backbone_resolution ({backbone_resolution}) must be a multiple of patch_size ({self.config.patch_size})" | |
| ) | |
| spatial_resolution = backbone_resolution // self.config.patch_size | |
| if backbone_resolution < 320: | |
| import warnings | |
| warnings.warn( | |
| f"backbone_resolution={backbone_resolution} is below 320; the detection " | |
| f"head's coarsest FPN level (stride 128) collapses to <=2 locations per " | |
| f"side and the detection graph, while it exports and runs, cannot produce " | |
| f"useful detections at this resolution. Classifier, seg, and depth graphs " | |
| f"are unaffected. FCOS convention is 640-800px input; export at " | |
| f">= 512 for detection.", | |
| stacklevel=2, | |
| ) | |
| wrapper = _BackboneExportWrapper(self.backbone).to(self.device).eval() | |
| dummy_image = torch.randn( | |
| 1, 3, backbone_resolution, backbone_resolution, | |
| device=self.device, dtype=torch.float32, | |
| ) | |
| dummy_spatial = torch.randn( | |
| 1, self.config.embed_dim, spatial_resolution, spatial_resolution, | |
| device=self.device, dtype=torch.float32, | |
| ) | |
| backbone_path = os.path.join(out_dir, "argus_backbone.onnx") | |
| classifier_path = os.path.join(out_dir, "argus_classifier.onnx") | |
| seg_path = os.path.join(out_dir, "argus_seg_head.onnx") | |
| depth_path = os.path.join(out_dir, "argus_depth_head.onnx") | |
| detection_path = os.path.join(out_dir, "argus_detection_head.onnx") | |
| backbone_axes = None | |
| head_axes = None | |
| if dynamic_batch: | |
| backbone_axes = { | |
| "image": {0: "batch"}, | |
| "cls_token": {0: "batch"}, | |
| "spatial_features": {0: "batch"}, | |
| } | |
| head_axes = { | |
| "spatial_features": {0: "batch"}, | |
| "seg_logits": {0: "batch"}, | |
| "depth_map": {0: "batch"}, | |
| } | |
| # dynamo path crashes on EUPE's list-based forward; use legacy. | |
| with torch.inference_mode(): | |
| torch.onnx.export( | |
| wrapper, dummy_image, backbone_path, | |
| input_names=["image"], | |
| output_names=["cls_token", "spatial_features"], | |
| dynamic_axes=backbone_axes, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| dynamo=False, | |
| ) | |
| seg_wrapper = _SegHeadExportWrapper(self.seg_head, backbone_resolution).to(self.device).eval() | |
| torch.onnx.export( | |
| seg_wrapper, dummy_spatial, seg_path, | |
| input_names=["spatial_features"], | |
| output_names=["seg_logits"], | |
| dynamic_axes={"spatial_features": head_axes["spatial_features"], "seg_logits": head_axes["seg_logits"]} if head_axes else None, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| dynamo=False, | |
| ) | |
| depth_wrapper = _DepthHeadExportWrapper( | |
| self.depth_head, spatial_resolution, spatial_resolution | |
| ).to(self.device).eval() | |
| num_patch_tokens = spatial_resolution * spatial_resolution + N_PREFIX_TOKENS | |
| dummy_inter = tuple( | |
| torch.randn(1, num_patch_tokens, self.config.embed_dim, | |
| device=self.device, dtype=torch.float32) | |
| for _ in range(len(HOOK_BLOCK_INDICES)) | |
| ) | |
| depth_input_names = [f"intermediate_{i}" for i in range(len(HOOK_BLOCK_INDICES))] | |
| if dynamic_batch: | |
| depth_axes = {name: {0: "batch"} for name in depth_input_names} | |
| depth_axes["depth_map"] = {0: "batch"} | |
| else: | |
| depth_axes = None | |
| torch.onnx.export( | |
| depth_wrapper, dummy_inter, depth_path, | |
| input_names=depth_input_names, | |
| output_names=["depth_map"], | |
| dynamic_axes=depth_axes, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| dynamo=False, | |
| ) | |
| classifier_wrapper = _ClassifierExportWrapper( | |
| self.class_logit_weight, self.class_logit_bias | |
| ).to(self.device).eval() | |
| dummy_cls = torch.randn( | |
| 1, self.config.embed_dim, device=self.device, dtype=torch.float32, | |
| ) | |
| if dynamic_batch: | |
| classifier_axes = {"cls_token": {0: "batch"}, "class_probs": {0: "batch"}} | |
| else: | |
| classifier_axes = None | |
| torch.onnx.export( | |
| classifier_wrapper, dummy_cls, classifier_path, | |
| input_names=["cls_token"], | |
| output_names=["class_probs"], | |
| dynamic_axes=classifier_axes, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| dynamo=False, | |
| ) | |
| detection_wrapper = _DetectionHeadExportWrapper( | |
| self.detection_head, backbone_resolution, | |
| include_nms=include_nms, | |
| nms_iou_threshold=nms_iou_threshold, | |
| nms_score_threshold=nms_score_threshold, | |
| nms_max_detections=nms_max_detections, | |
| ).to(self.device).eval() | |
| if include_nms: | |
| detection_output_names = ["boxes", "scores", "class_labels", "batch_indices"] | |
| # Post-NMS outputs are flat [M, ...]; no fixed batch axis to mark. | |
| # Spatial features input still has a dynamic batch dim so the graph | |
| # supports multi-image inference even with fused NMS. | |
| detection_axes = {"spatial_features": {0: "batch"}} if dynamic_batch else None | |
| else: | |
| detection_output_names = ["boxes", "scores"] | |
| if dynamic_batch: | |
| detection_axes = { | |
| "spatial_features": {0: "batch"}, | |
| "boxes": {0: "batch"}, | |
| "scores": {0: "batch"}, | |
| } | |
| else: | |
| detection_axes = None | |
| torch.onnx.export( | |
| detection_wrapper, dummy_spatial, detection_path, | |
| input_names=["spatial_features"], | |
| output_names=detection_output_names, | |
| dynamic_axes=detection_axes, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| dynamo=False, | |
| ) | |
| result = { | |
| "backbone": backbone_path, | |
| "classifier": classifier_path, | |
| "seg_head": seg_path, | |
| "depth_head": depth_path, | |
| "detection_head": detection_path, | |
| } | |
| if verify: | |
| try: | |
| import onnxruntime as ort | |
| except ImportError as e: | |
| raise ImportError("onnxruntime not installed; pip install onnxruntime") from e | |
| providers = ["CPUExecutionProvider"] | |
| verify_image = torch.randn(2, 3, backbone_resolution, backbone_resolution, dtype=torch.float32) | |
| verify_spatial = torch.randn(2, self.config.embed_dim, spatial_resolution, spatial_resolution, dtype=torch.float32) | |
| verify_cls = torch.randn(2, self.config.embed_dim, dtype=torch.float32) | |
| verify_inter = [ | |
| torch.randn(2, num_patch_tokens, self.config.embed_dim, dtype=torch.float32) | |
| for _ in range(len(HOOK_BLOCK_INDICES)) | |
| ] | |
| with torch.inference_mode(): | |
| ref_cls, ref_spatial = wrapper(verify_image.to(self.device)) | |
| ref_seg = seg_wrapper(verify_spatial.to(self.device)) | |
| ref_depth = depth_wrapper(*[v.to(self.device) for v in verify_inter]) | |
| ref_probs = classifier_wrapper(verify_cls.to(self.device)) | |
| ref_det = detection_wrapper(verify_spatial.to(self.device)) | |
| sess = ort.InferenceSession(backbone_path, providers=providers) | |
| ort_cls, ort_spatial = sess.run(None, {"image": verify_image.numpy()}) | |
| cls_diff = float(np.abs(ort_cls - ref_cls.cpu().numpy()).max()) | |
| spatial_diff = float(np.abs(ort_spatial - ref_spatial.cpu().numpy()).max()) | |
| sess = ort.InferenceSession(seg_path, providers=providers) | |
| ort_seg = sess.run(None, {"spatial_features": verify_spatial.numpy()})[0] | |
| seg_diff = float(np.abs(ort_seg - ref_seg.cpu().numpy()).max()) | |
| sess = ort.InferenceSession(depth_path, providers=providers) | |
| ort_depth = sess.run(None, {f"intermediate_{i}": verify_inter[i].numpy() | |
| for i in range(len(HOOK_BLOCK_INDICES))})[0] | |
| depth_diff = float(np.abs(ort_depth - ref_depth.cpu().numpy()).max()) | |
| sess = ort.InferenceSession(classifier_path, providers=providers) | |
| ort_probs = sess.run(None, {"cls_token": verify_cls.numpy()})[0] | |
| classifier_diff = float(np.abs(ort_probs - ref_probs.cpu().numpy()).max()) | |
| sess = ort.InferenceSession(detection_path, providers=providers) | |
| ort_det = sess.run(None, {"spatial_features": verify_spatial.numpy()}) | |
| verification = { | |
| "backbone_cls_max_diff": cls_diff, | |
| "backbone_spatial_max_diff": spatial_diff, | |
| "classifier_max_diff": classifier_diff, | |
| "seg_head_max_diff": seg_diff, | |
| "depth_head_max_diff": depth_diff, | |
| "verified_batch_size": 2, | |
| } | |
| if include_nms: | |
| # NMS is inherently implementation-dependent: ONNX's | |
| # NonMaxSuppression and the torchvision eager fallback differ | |
| # on tie-breaking when multiple detections share a score or | |
| # when near-threshold boxes are right at the score cutoff. | |
| # Element-wise comparison of post-NMS outputs is the wrong | |
| # metric. The structural checks below verify the graph runs, | |
| # returns reasonable shapes, and agrees on the top detection. | |
| pt_boxes, pt_scores, pt_labels, _ = ref_det | |
| ort_boxes, ort_scores, ort_labels, _ = ort_det | |
| pt_n = int(pt_scores.shape[0]) | |
| ort_n = int(ort_scores.shape[0]) | |
| verification["detection_nms_ref_count"] = pt_n | |
| verification["detection_nms_ort_count"] = ort_n | |
| if pt_n > 0 and ort_n > 0: | |
| pt_top = int(pt_scores.cpu().numpy().argmax()) | |
| ort_top = int(ort_scores.argmax()) | |
| pt_top_box = pt_boxes[pt_top].cpu().numpy() | |
| ort_top_box = ort_boxes[ort_top] | |
| # IoU of the two top boxes | |
| x1 = max(pt_top_box[0], ort_top_box[0]) | |
| y1 = max(pt_top_box[1], ort_top_box[1]) | |
| x2 = min(pt_top_box[2], ort_top_box[2]) | |
| y2 = min(pt_top_box[3], ort_top_box[3]) | |
| inter = max(0.0, x2 - x1) * max(0.0, y2 - y1) | |
| pt_area = max(0.0, pt_top_box[2] - pt_top_box[0]) * max(0.0, pt_top_box[3] - pt_top_box[1]) | |
| ort_area = max(0.0, ort_top_box[2] - ort_top_box[0]) * max(0.0, ort_top_box[3] - ort_top_box[1]) | |
| union = max(1e-6, pt_area + ort_area - inter) | |
| verification["detection_nms_top_iou"] = float(inter / union) | |
| verification["detection_nms_top_class_match"] = bool( | |
| int(pt_labels[pt_top].cpu()) == int(ort_labels[ort_top]) | |
| ) | |
| verification["detection_nms_top_score_diff"] = float(abs( | |
| float(pt_scores[pt_top].cpu()) - float(ort_scores[ort_top]) | |
| )) | |
| else: | |
| verification["detection_nms_top_iou"] = None | |
| verification["detection_nms_top_class_match"] = None | |
| verification["detection_nms_top_score_diff"] = None | |
| else: | |
| ort_boxes, ort_scores = ort_det | |
| ref_boxes, ref_scores = ref_det | |
| verification["detection_boxes_max_diff"] = float( | |
| np.abs(ort_boxes - ref_boxes.cpu().numpy()).max()) | |
| verification["detection_scores_max_diff"] = float( | |
| np.abs(ort_scores - ref_scores.cpu().numpy()).max()) | |
| # Tolerance resolution: either a float applied uniformly, or a dict | |
| # keyed by verification output name (with optional "default" key). | |
| # Detection boxes get a resolution-scaled tolerance when only a | |
| # float is supplied — exp() in the FCOS regression path amplifies | |
| # FP kernel-dispatch differences to pixel-scale absolute diffs. | |
| if isinstance(tolerance, dict): | |
| default_tol = float(tolerance.get("default", 5e-2)) | |
| def _tol_for(key): | |
| return float(tolerance.get(key, default_tol)) | |
| verification["tolerance"] = dict(tolerance) | |
| else: | |
| base = float(tolerance) | |
| box_tol = max(base, backbone_resolution * 5e-3) | |
| def _tol_for(key): | |
| return box_tol if key == "detection_boxes_max_diff" else base | |
| verification["tolerance"] = base | |
| verification["detection_boxes_tolerance"] = box_tol | |
| for key, val in list(verification.items()): | |
| if not key.endswith("_max_diff"): | |
| continue | |
| t = _tol_for(key) | |
| if val > t: | |
| raise RuntimeError( | |
| f"ONNX/PyTorch divergence in {key}: {val:.2e} > tolerance {t:.2e}" | |
| ) | |
| result["verification"] = verification | |
| return result | |