import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin # used for model hub
from streamvggt.models.aggregator import Aggregator
from streamvggt.heads.camera_head import CameraHead
from streamvggt.heads.dpt_head import DPTHead
from streamvggt.heads.track_head import TrackHead
from transformers.file_utils import ModelOutput
from typing import Optional, Tuple, List, Any
from dataclasses import dataclass

@dataclass
class StreamVGGTOutput(ModelOutput):
    ress: Optional[List[dict]] = None   # per-frame prediction dicts
    views: Optional[List[dict]] = None  # the input views, passed through unchanged

class StreamVGGT(nn.Module, PyTorchModelHubMixin):
    def __init__(self, img_size=518, patch_size=14, embed_dim=1024):
        super().__init__()
        self.aggregator = Aggregator(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        # Prediction heads: camera pose encoding, 3D points, depth, and point tracks
        self.camera_head = CameraHead(dim_in=2 * embed_dim)
        self.point_head = DPTHead(dim_in=2 * embed_dim, output_dim=4, activation="inv_log", conf_activation="expp1")
        self.depth_head = DPTHead(dim_in=2 * embed_dim, output_dim=2, activation="exp", conf_activation="expp1")
        self.track_head = TrackHead(dim_in=2 * embed_dim, patch_size=patch_size)

    def forward(
        self,
        views,
        query_points: torch.Tensor = None,
        history_info: Optional[dict] = None,
        past_key_values=None,
        use_cache=False,
        past_frame_idx=0
    ):
        # Stack per-view images and move the sequence axis after the batch axis
        images = torch.stack(
            [view["img"] for view in views], dim=0
        ).permute(1, 0, 2, 3, 4)  # B S C H W

        # If the batch dimension is missing, add it
        if len(images.shape) == 4:
            images = images.unsqueeze(0)
        if query_points is not None and len(query_points.shape) == 2:
            query_points = query_points.unsqueeze(0)
        if history_info is None:
            history_info = {"token": None}

        aggregated_tokens_list, patch_start_idx = self.aggregator(images)

        predictions = {}
        with torch.cuda.amp.autocast(enabled=False):
            if self.camera_head is not None:
                pose_enc_list = self.camera_head(aggregated_tokens_list)
                predictions["pose_enc"] = pose_enc_list[-1]  # pose encoding of the last iteration

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["depth"] = depth
                predictions["depth_conf"] = depth_conf

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx
                )
                predictions["world_points"] = pts3d
                predictions["world_points_conf"] = pts3d_conf

            if self.track_head is not None and query_points is not None:
                track_list, vis, conf = self.track_head(
                    aggregated_tokens_list, images=images, patch_start_idx=patch_start_idx, query_points=query_points
                )
                predictions["track"] = track_list[-1]  # track of the last iteration
                predictions["vis"] = vis
                predictions["conf"] = conf

        predictions["images"] = images

        B, S = images.shape[:2]
        ress = []
        for s in range(S):
            res = {
                'pts3d_in_other_view': predictions['world_points'][:, s],  # [B, H, W, 3]
                'conf': predictions['world_points_conf'][:, s],            # [B, H, W]
                'depth': predictions['depth'][:, s],                       # [B, H, W, 1]
                'depth_conf': predictions['depth_conf'][:, s],             # [B, H, W]
                'camera_pose': predictions['pose_enc'][:, s, :],           # [B, 9]
                **({'valid_mask': views[s]["valid_mask"]}
                   if 'valid_mask' in views[s] else {}),                   # [B, H, W]
                **({'track': predictions['track'][:, s],                   # [B, N, 2]
                    'vis': predictions['vis'][:, s],                       # [B, N]
                    'track_conf': predictions['conf'][:, s]}
                   if 'track' in predictions else {})
            }
            ress.append(res)
        return StreamVGGTOutput(ress=ress, views=views)  # [S] [B, C, H, W]

    def inference(self, frames, query_points: torch.Tensor = None, past_key_values=None):
        # Streaming inference: frames are processed one at a time, reusing the
        # aggregator's and camera head's key/value caches across frames.
        past_key_values = [None] * self.aggregator.depth
        past_key_values_camera = [None] * self.camera_head.trunk_depth
        all_ress = []
        processed_frames = []
        for i, frame in enumerate(frames):
            images = frame["img"].unsqueeze(0)
            aggregator_output = self.aggregator(
                images,
                past_key_values=past_key_values,
                use_cache=True,
                past_frame_idx=i
            )
            if isinstance(aggregator_output, tuple) and len(aggregator_output) == 3:
                aggregated_tokens, patch_start_idx, past_key_values = aggregator_output
            else:
                aggregated_tokens, patch_start_idx = aggregator_output
            with torch.cuda.amp.autocast(enabled=False):
                if self.camera_head is not None:
                    pose_enc, past_key_values_camera = self.camera_head(
                        aggregated_tokens, past_key_values_camera=past_key_values_camera, use_cache=True
                    )
                    pose_enc = pose_enc[-1]
                    camera_pose = pose_enc[:, 0, :]

                if self.depth_head is not None:
                    depth, depth_conf = self.depth_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx
                    )
                    depth = depth[:, 0]
                    depth_conf = depth_conf[:, 0]

                if self.point_head is not None:
                    pts3d, pts3d_conf = self.point_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx
                    )
                    pts3d = pts3d[:, 0]
                    pts3d_conf = pts3d_conf[:, 0]

                if self.track_head is not None and query_points is not None:
                    track_list, vis, conf = self.track_head(
                        aggregated_tokens, images=images, patch_start_idx=patch_start_idx, query_points=query_points
                    )
                    track = track_list[-1][:, 0]
                    query_points = track  # feed the latest track as queries for the next frame
                    vis = vis[:, 0]
                    track_conf = conf[:, 0]

            all_ress.append({
                'pts3d_in_other_view': pts3d,
                'conf': pts3d_conf,
                'depth': depth,
                'depth_conf': depth_conf,
                'camera_pose': camera_pose,
                **({'valid_mask': frame["valid_mask"]}
                   if 'valid_mask' in frame else {}),
                **({'track': track,
                    'vis': vis,
                    'track_conf': track_conf}
                   if query_points is not None else {})
            })
            processed_frames.append(frame)

        output = StreamVGGTOutput(ress=all_ress, views=processed_frames)
        return output
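

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# checkpoint weights are available through the PyTorchModelHubMixin interface;
# the hub repository id and the `load_and_preprocess_frames` helper below are
# hypothetical placeholders.
#
#   model = StreamVGGT.from_pretrained("<hub-repo-id>").eval().cuda()
#   frames = load_and_preprocess_frames(video_path)  # list of dicts, each with an "img" tensor
#   with torch.no_grad():
#       output = model.inference(frames)
#   for res in output.ress:
#       depth, pose = res["depth"], res["camera_pose"]
# ----------------------------------------------------------------------------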