# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# modified from DUSt3R
import torch
import torch.nn as nn
import torch.nn.functional as F

from dust3r.heads.postprocess import (
    postprocess,
    postprocess_desc,
    postprocess_rgb,
    postprocess_pose_conf,
    postprocess_pose,
    reg_dense_conf,
)
import dust3r.utils.path_to_croco  # noqa
from models.blocks import Mlp  # noqa
from dust3r.utils.geometry import geotrf
from dust3r.utils.camera import pose_encoding_to_camera, PoseDecoder
from dust3r.blocks import ConditionModulationBlock


class LinearPts3d(nn.Module):
    """
    Linear head for DUSt3R.
    Each token is decoded into a patch_size x patch_size block of 3D points
    (+ confidence), with optional self-view points, RGB and pose-confidence maps.
    """

    def __init__(
        self, net, has_conf=False, has_depth=False, has_rgb=False, has_pose_conf=False
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose_conf = has_pose_conf
        self.has_depth = has_depth

        self.proj = Mlp(
            net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
        )
        if has_depth:
            self.self_proj = Mlp(
                net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
            )
        if has_rgb:
            self.rgb_proj = Mlp(net.dec_embed_dim, out_features=3 * self.patch_size**2)
        if has_pose_conf:
            # forward() expects this projection; assumed to output one
            # confidence value per pixel (patch_size**2 values per token)
            self.pose_conf_proj = Mlp(
                net.dec_embed_dim, out_features=self.patch_size**2
            )
    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # cross-view 3D points (+ confidence)
        feat = self.proj(tokens)  # B,S,(3+conf)*patch_size**2
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,(3+conf),H,W
        final_output = postprocess(feat, self.depth_mode, self.conf_mode)
        final_output["pts3d_in_other_view"] = final_output.pop("pts3d")

        if self.has_depth:
            # self-view 3D points (+ confidence)
            self_feat = self.self_proj(tokens)  # B,S,(3+conf)*patch_size**2
            self_feat = self_feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            self_feat = F.pixel_shuffle(self_feat, self.patch_size)  # B,(3+conf),H,W
            self_3d_output = postprocess(self_feat, self.depth_mode, self.conf_mode)
            self_3d_output["pts3d_in_self_view"] = self_3d_output.pop("pts3d")
            self_3d_output["conf_self"] = self_3d_output.pop("conf")
            final_output.update(self_3d_output)

        if self.has_rgb:
            # per-pixel RGB reconstruction
            rgb_feat = self.rgb_proj(tokens)
            rgb_feat = rgb_feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size)  # B,3,H,W
            rgb_output = postprocess_rgb(rgb_feat)
            final_output.update(rgb_output)

        if self.has_pose_conf:
            # per-pixel pose confidence
            pose_conf = self.pose_conf_proj(tokens)
            pose_conf = pose_conf.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            pose_conf = F.pixel_shuffle(pose_conf, self.patch_size)
            pose_conf_output = postprocess_pose_conf(pose_conf)
            final_output.update(pose_conf_output)

        return final_output
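
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: it shows how a linear
# head of this kind turns per-patch tokens into a dense per-pixel map via a
# linear projection followed by pixel_shuffle. All dimensions below (patch
# size, token count, embedding width) are assumptions chosen for the example;
# the real heads use `Mlp` and the `postprocess` functions imported above.
def _sketch_token_to_dense_map():
    B, patch_size = 2, 16
    H, W = 224, 224
    S = (H // patch_size) * (W // patch_size)  # number of patch tokens
    D = 768  # assumed decoder embedding width
    tokens = torch.randn(B, S, D)
    proj = nn.Linear(D, (3 + 1) * patch_size**2)  # 3 coords + 1 confidence per pixel
    feat = proj(tokens)  # B, S, (3+1)*patch_size**2
    feat = feat.transpose(-1, -2).view(B, -1, H // patch_size, W // patch_size)
    feat = F.pixel_shuffle(feat, patch_size)  # B, 4, H, W
    assert feat.shape == (B, 4, H, W)
    return feat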


class LinearPts3d_Desc(nn.Module):
    """
    Linear head for DUSt3R.
    Each token is decoded into a patch_size x patch_size block of 3D points
    (+ confidence), plus a dense map of local descriptor features.
    """

    def __init__(
        self,
        net,
        has_conf=False,
        has_depth=False,
        local_feat_dim=24,
        hidden_dim_factor=4.0,
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf
        self.double_channel = has_depth
        self.local_feat_dim = local_feat_dim

        if not has_depth:
            self.proj = nn.Linear(
                net.dec_embed_dim, (3 + has_conf) * self.patch_size**2
            )
        else:
            # both point maps (self view and cross view) are regressed at once
            self.proj = nn.Linear(
                net.dec_embed_dim, (3 + has_conf) * 2 * self.patch_size**2
            )

        # local descriptor head operates on concatenated encoder + decoder tokens
        idim = net.enc_embed_dim + net.dec_embed_dim
        self.head_local_features = Mlp(
            in_features=idim,
            hidden_features=int(hidden_dim_factor * idim),
            out_features=(self.local_feat_dim + 1) * self.patch_size**2,
        )
    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # 3D points (+ confidence)
        feat = self.proj(tokens)  # B,S,(3+conf)*patch_size**2
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,(3+conf),H,W

        # local descriptor features from concatenated encoder + decoder tokens
        enc_output, dec_output = decout[0], decout[-1]
        cat_output = torch.cat([enc_output, dec_output], dim=-1)  # B,S,enc+dec
        local_features = self.head_local_features(cat_output)  # B,S,(d+1)*patch_size**2
        local_features = local_features.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        local_features = F.pixel_shuffle(local_features, self.patch_size)  # B,d+1,H,W

        feat = torch.cat([feat, local_features], dim=1)
        return postprocess_desc(
            feat,
            self.depth_mode,
            self.conf_mode,
            self.local_feat_dim,
            self.double_channel,
        )
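
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: LinearPts3d_Desc also
# projects the concatenation of encoder and decoder tokens to a dense
# (local_feat_dim + 1)-channel descriptor map. The widths below are assumptions
# chosen for the example; the real head uses `Mlp` with net.enc_embed_dim and
# net.dec_embed_dim.
def _sketch_local_descriptor_map():
    B, patch_size, local_feat_dim = 2, 16, 24
    H, W = 224, 224
    S = (H // patch_size) * (W // patch_size)
    enc_dim, dec_dim = 1024, 768  # assumed encoder / decoder widths
    enc_tokens = torch.randn(B, S, enc_dim)
    dec_tokens = torch.randn(B, S, dec_dim)
    cat_tokens = torch.cat([enc_tokens, dec_tokens], dim=-1)  # B, S, enc_dim+dec_dim
    head = nn.Linear(enc_dim + dec_dim, (local_feat_dim + 1) * patch_size**2)
    desc = head(cat_tokens)
    desc = desc.transpose(-1, -2).view(B, -1, H // patch_size, W // patch_size)
    desc = F.pixel_shuffle(desc, patch_size)  # B, local_feat_dim+1, H, W
    assert desc.shape == (B, local_feat_dim + 1, H, W)
    return desc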


class LinearPts3dPoseDirect(nn.Module):
    """
    Linear head for DUSt3R.
    Each token is decoded into a patch_size x patch_size block of 3D points
    (+ confidence); a dedicated pose token is decoded into a camera pose, and
    cross-view points are obtained by transforming the self-view points with it.
    """

    def __init__(self, net, has_conf=False, has_rgb=False, has_pose=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.pose_mode = net.pose_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose = has_pose

        self.proj = Mlp(
            net.dec_embed_dim, out_features=(3 + has_conf) * self.patch_size**2
        )
        if has_rgb:
            self.rgb_proj = Mlp(net.dec_embed_dim, out_features=3 * self.patch_size**2)
        if has_pose:
            self.pose_head = PoseDecoder(hidden_size=net.dec_embed_dim)
        if has_conf:
            self.cross_conf_proj = Mlp(
                net.dec_embed_dim, out_features=self.patch_size**2
            )
    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        if self.has_pose:
            # the first token is a dedicated pose token, the rest are patch tokens
            pose_token = tokens[:, 0]
            tokens = tokens[:, 1:]
        B, S, D = tokens.shape

        # self-view 3D points (+ confidence)
        feat = self.proj(tokens)  # B,S,(3+conf)*patch_size**2
        feat = feat.transpose(-1, -2).view(
            B, -1, H // self.patch_size, W // self.patch_size
        )
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,(3+conf),H,W
        final_output = postprocess(feat, self.depth_mode, self.conf_mode)
        final_output["pts3d_in_self_view"] = final_output.pop("pts3d")
        final_output["conf_self"] = final_output.pop("conf")

        if self.has_rgb:
            # per-pixel RGB reconstruction
            rgb_feat = self.rgb_proj(tokens)
            rgb_feat = rgb_feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size)  # B,3,H,W
            rgb_output = postprocess_rgb(rgb_feat)
            final_output.update(rgb_output)

        if self.has_pose:
            # decode the camera pose and map self-view points into the other view
            pose = self.pose_head(pose_token)
            pose = postprocess_pose(pose, self.pose_mode)
            final_output["camera_pose"] = pose  # B,7
            final_output["pts3d_in_other_view"] = geotrf(
                pose_encoding_to_camera(final_output["camera_pose"]),
                final_output["pts3d_in_self_view"],
            )

        if self.has_conf:
            # cross-view confidence map
            cross_conf = self.cross_conf_proj(tokens)
            cross_conf = cross_conf.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            cross_conf = F.pixel_shuffle(cross_conf, self.patch_size)[:, 0]
            final_output["conf"] = reg_dense_conf(cross_conf, mode=self.conf_mode)

        return final_output
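
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: LinearPts3dPoseDirect
# obtains cross-view points by transforming the self-view point map with the
# decoded camera pose (via dust3r's `pose_encoding_to_camera` and `geotrf`).
# The helper below is a simplified stand-in that applies a 4x4 rigid transform
# to a (B, H, W, 3) point map; the pose convention here is an assumption for
# the example, not the library's definition.
def _sketch_apply_pose(pose_4x4: torch.Tensor, pts3d: torch.Tensor) -> torch.Tensor:
    R = pose_4x4[:, :3, :3]  # B,3,3 rotation
    t = pose_4x4[:, :3, 3]  # B,3 translation
    return torch.einsum("bij,bhwj->bhwi", R, pts3d) + t[:, None, None, :]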


class LinearPts3dPose(nn.Module):
    """
    Linear head for DUSt3R.
    Each token is decoded into a patch_size x patch_size block of 3D points
    (+ confidence). A dedicated pose token is decoded into a camera pose, and
    pose-conditioned transformer blocks produce the cross-view point map.
    """

    def __init__(
        self, net, has_conf=False, has_rgb=False, has_pose=False, mlp_ratio=4.0
    ):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.pose_mode = net.pose_mode
        self.has_conf = has_conf
        self.has_rgb = has_rgb
        self.has_pose = has_pose

        self.proj = Mlp(
            net.dec_embed_dim,
            hidden_features=int(mlp_ratio * net.dec_embed_dim),
            out_features=(3 + has_conf) * self.patch_size**2,
        )
        if has_rgb:
            self.rgb_proj = Mlp(
                net.dec_embed_dim,
                hidden_features=int(mlp_ratio * net.dec_embed_dim),
                out_features=3 * self.patch_size**2,
            )
        if has_pose:
            self.pose_head = PoseDecoder(hidden_size=net.dec_embed_dim)
            # blocks that modulate patch tokens with the decoded pose token
            self.final_transform = nn.ModuleList(
                [
                    ConditionModulationBlock(
                        net.dec_embed_dim,
                        net.dec_num_heads,
                        mlp_ratio=4.0,
                        qkv_bias=True,
                        rope=net.rope,
                    )
                    for _ in range(2)
                ]
            )
            self.cross_proj = Mlp(
                net.dec_embed_dim,
                hidden_features=int(mlp_ratio * net.dec_embed_dim),
                out_features=(3 + has_conf) * self.patch_size**2,
            )
    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape, **kwargs):
        H, W = img_shape
        tokens = decout[-1]
        if self.has_pose:
            # the first token is a dedicated pose token, the rest are patch tokens
            pose_token = tokens[:, 0]
            tokens = tokens[:, 1:]
            # decode the camera pose in full precision
            with torch.cuda.amp.autocast(enabled=False):
                pose = self.pose_head(pose_token)
            # modulate patch tokens with the pose token for the cross-view branch
            cross_tokens = tokens
            for blk in self.final_transform:
                cross_tokens = blk(cross_tokens, pose_token, kwargs.get("pos"))

        with torch.cuda.amp.autocast(enabled=False):
            B, S, D = tokens.shape

            # self-view 3D points (+ confidence)
            feat = self.proj(tokens)  # B,S,(3+conf)*patch_size**2
            feat = feat.transpose(-1, -2).view(
                B, -1, H // self.patch_size, W // self.patch_size
            )
            feat = F.pixel_shuffle(feat, self.patch_size)  # B,(3+conf),H,W
            final_output = postprocess(
                feat, self.depth_mode, self.conf_mode, pos_z=True
            )
            final_output["pts3d_in_self_view"] = final_output.pop("pts3d")
            final_output["conf_self"] = final_output.pop("conf")

            if self.has_rgb:
                # per-pixel RGB reconstruction
                rgb_feat = self.rgb_proj(tokens)
                rgb_feat = rgb_feat.transpose(-1, -2).view(
                    B, -1, H // self.patch_size, W // self.patch_size
                )
                rgb_feat = F.pixel_shuffle(rgb_feat, self.patch_size)  # B,3,H,W
                rgb_output = postprocess_rgb(rgb_feat)
                final_output.update(rgb_output)

            if self.has_pose:
                pose = postprocess_pose(pose, self.pose_mode)
                final_output["camera_pose"] = pose  # B,7

                # cross-view 3D points (+ confidence) from pose-modulated tokens
                cross_feat = self.cross_proj(cross_tokens)  # B,S,(3+conf)*patch_size**2
                cross_feat = cross_feat.transpose(-1, -2).view(
                    B, -1, H // self.patch_size, W // self.patch_size
                )
                cross_feat = F.pixel_shuffle(cross_feat, self.patch_size)  # B,(3+conf),H,W
                tmp = postprocess(cross_feat, self.depth_mode, self.conf_mode)
                final_output["pts3d_in_other_view"] = tmp.pop("pts3d")
                final_output["conf"] = tmp.pop("conf")

        return final_output
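
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original model: LinearPts3dPose keeps
# the pose head and the final point regression in full precision by disabling
# autocast locally (see forward above). The toy helper below shows the same
# pattern; the `head` argument and its shapes are assumptions for the example.
def _sketch_fp32_island(tokens: torch.Tensor, head: nn.Module) -> torch.Tensor:
    # under an outer autocast() context, `tokens` may be fp16/bf16; the head is
    # evaluated in fp32 to keep the regression numerically stable
    with torch.cuda.amp.autocast(enabled=False):
        return head(tokens.float())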