phanerozoic committed on
Commit
3fb670e
·
verified ·
1 Parent(s): 23d2b4d

argus.py: reflect-padded DPT convs, drop post-hoc border crop

Browse files
Files changed (1) hide show
  1. argus.py +4 -9
argus.py CHANGED
@@ -1429,9 +1429,9 @@ N_PREFIX_TOKENS = 5 # 1 CLS + 4 register/storage tokens
1429
  class _ResidualConvUnit(nn.Module):
1430
  def __init__(self, dim: int):
1431
  super().__init__()
1432
- self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, bias=False)
1433
  self.bn1 = nn.BatchNorm2d(dim)
1434
- self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, bias=False)
1435
  self.bn2 = nn.BatchNorm2d(dim)
1436
  self.act = nn.GELU()
1437
 
@@ -1463,7 +1463,7 @@ class _DPTReassemble(nn.Module):
1463
  ])
1464
  self.refine = nn.ModuleList([
1465
  nn.Sequential(
1466
- nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False),
1467
  nn.BatchNorm2d(out_dim),
1468
  nn.GELU(),
1469
  )
@@ -1502,7 +1502,7 @@ class DPTDepthDecoder(nn.Module):
1502
  _FeatureFusionBlock(decoder_dim, has_skip=False),
1503
  ])
1504
  self.head = nn.Sequential(
1505
- nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, bias=False),
1506
  nn.BatchNorm2d(decoder_dim),
1507
  nn.GELU(),
1508
  nn.Conv2d(decoder_dim, n_bins, 1),
@@ -1790,13 +1790,8 @@ class Argus(PreTrainedModel):
1790
  depth_b = self.depth_head(inter_list, H, W)
1791
  std_b = None
1792
 
1793
- # Crop the DPT fusion border artifact (zero-padding in the conv chain
1794
- # produces systematically wrong edge values that compound across 4 stages)
1795
- crop = max(4, depth_b.shape[2] // 13)
1796
- depth_b = depth_b[:, :, crop:-crop, crop:-crop]
1797
  depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
1798
  if std_b is not None:
1799
- std_b = std_b[:, :, crop:-crop, crop:-crop]
1800
  std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
1801
 
1802
  depth_squeezed = depth_b[:, 0].float()
 
1429
  class _ResidualConvUnit(nn.Module):
1430
  def __init__(self, dim: int):
1431
  super().__init__()
1432
+ self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode="reflect", bias=False)
1433
  self.bn1 = nn.BatchNorm2d(dim)
1434
+ self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode="reflect", bias=False)
1435
  self.bn2 = nn.BatchNorm2d(dim)
1436
  self.act = nn.GELU()
1437
 
 
1463
  ])
1464
  self.refine = nn.ModuleList([
1465
  nn.Sequential(
1466
+ nn.Conv2d(out_dim, out_dim, 3, padding=1, padding_mode="reflect", bias=False),
1467
  nn.BatchNorm2d(out_dim),
1468
  nn.GELU(),
1469
  )
 
1502
  _FeatureFusionBlock(decoder_dim, has_skip=False),
1503
  ])
1504
  self.head = nn.Sequential(
1505
+ nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, padding_mode="reflect", bias=False),
1506
  nn.BatchNorm2d(decoder_dim),
1507
  nn.GELU(),
1508
  nn.Conv2d(decoder_dim, n_bins, 1),
 
1790
  depth_b = self.depth_head(inter_list, H, W)
1791
  std_b = None
1792
 
 
 
 
 
1793
  depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
1794
  if std_b is not None:
 
1795
  std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
1796
 
1797
  depth_squeezed = depth_b[:, 0].float()