argus.py: reflect-padded DPT convs, drop post-hoc border crop
Browse files
argus.py
CHANGED
|
@@ -1429,9 +1429,9 @@ N_PREFIX_TOKENS = 5 # 1 CLS + 4 register/storage tokens
|
|
| 1429 |
class _ResidualConvUnit(nn.Module):
|
| 1430 |
def __init__(self, dim: int):
|
| 1431 |
super().__init__()
|
| 1432 |
-
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, bias=False)
|
| 1433 |
self.bn1 = nn.BatchNorm2d(dim)
|
| 1434 |
-
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, bias=False)
|
| 1435 |
self.bn2 = nn.BatchNorm2d(dim)
|
| 1436 |
self.act = nn.GELU()
|
| 1437 |
|
|
@@ -1463,7 +1463,7 @@ class _DPTReassemble(nn.Module):
|
|
| 1463 |
])
|
| 1464 |
self.refine = nn.ModuleList([
|
| 1465 |
nn.Sequential(
|
| 1466 |
-
nn.Conv2d(out_dim, out_dim, 3, padding=1, bias=False),
|
| 1467 |
nn.BatchNorm2d(out_dim),
|
| 1468 |
nn.GELU(),
|
| 1469 |
)
|
|
@@ -1502,7 +1502,7 @@ class DPTDepthDecoder(nn.Module):
|
|
| 1502 |
_FeatureFusionBlock(decoder_dim, has_skip=False),
|
| 1503 |
])
|
| 1504 |
self.head = nn.Sequential(
|
| 1505 |
-
nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, bias=False),
|
| 1506 |
nn.BatchNorm2d(decoder_dim),
|
| 1507 |
nn.GELU(),
|
| 1508 |
nn.Conv2d(decoder_dim, n_bins, 1),
|
|
@@ -1790,13 +1790,8 @@ class Argus(PreTrainedModel):
|
|
| 1790 |
depth_b = self.depth_head(inter_list, H, W)
|
| 1791 |
std_b = None
|
| 1792 |
|
| 1793 |
-
# Crop the DPT fusion border artifact (zero-padding in the conv chain
|
| 1794 |
-
# produces systematically wrong edge values that compound across 4 stages)
|
| 1795 |
-
crop = max(4, depth_b.shape[2] // 13)
|
| 1796 |
-
depth_b = depth_b[:, :, crop:-crop, crop:-crop]
|
| 1797 |
depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
|
| 1798 |
if std_b is not None:
|
| 1799 |
-
std_b = std_b[:, :, crop:-crop, crop:-crop]
|
| 1800 |
std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
|
| 1801 |
|
| 1802 |
depth_squeezed = depth_b[:, 0].float()
|
|
|
|
| 1429 |
class _ResidualConvUnit(nn.Module):
|
| 1430 |
def __init__(self, dim: int):
|
| 1431 |
super().__init__()
|
| 1432 |
+
self.conv1 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode="reflect", bias=False)
|
| 1433 |
self.bn1 = nn.BatchNorm2d(dim)
|
| 1434 |
+
self.conv2 = nn.Conv2d(dim, dim, 3, padding=1, padding_mode="reflect", bias=False)
|
| 1435 |
self.bn2 = nn.BatchNorm2d(dim)
|
| 1436 |
self.act = nn.GELU()
|
| 1437 |
|
|
|
|
| 1463 |
])
|
| 1464 |
self.refine = nn.ModuleList([
|
| 1465 |
nn.Sequential(
|
| 1466 |
+
nn.Conv2d(out_dim, out_dim, 3, padding=1, padding_mode="reflect", bias=False),
|
| 1467 |
nn.BatchNorm2d(out_dim),
|
| 1468 |
nn.GELU(),
|
| 1469 |
)
|
|
|
|
| 1502 |
_FeatureFusionBlock(decoder_dim, has_skip=False),
|
| 1503 |
])
|
| 1504 |
self.head = nn.Sequential(
|
| 1505 |
+
nn.Conv2d(decoder_dim, decoder_dim, 3, padding=1, padding_mode="reflect", bias=False),
|
| 1506 |
nn.BatchNorm2d(decoder_dim),
|
| 1507 |
nn.GELU(),
|
| 1508 |
nn.Conv2d(decoder_dim, n_bins, 1),
|
|
|
|
| 1790 |
depth_b = self.depth_head(inter_list, H, W)
|
| 1791 |
std_b = None
|
| 1792 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1793 |
depth_b = F.interpolate(depth_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
|
| 1794 |
if std_b is not None:
|
|
|
|
| 1795 |
std_b = F.interpolate(std_b, size=(resolution, resolution), mode="bilinear", align_corners=False)
|
| 1796 |
|
| 1797 |
depth_squeezed = depth_b[:, 0].float()
|