# embodiedmae-base / modular_embodiedmae.py
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional, Tuple
import einops
import numba
import numpy as np
import pytorch3d.ops as torch3d_ops
import pytorch_lightning as L
import torch
import torch.nn as nn
from pytorch3d.loss import chamfer_distance
from transformers import (
AutoModelForMaskedImageModeling,
Dinov2Config,
Dinov2Model,
PretrainedConfig,
PreTrainedModel,
)
from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder, Dinov2Layer
from transformers.utils import ModelOutput
from .configuration_embodiedmae import EmbodiedMAEConfig
def concat_tensor(
tensors: List[torch.Tensor | None], dim: int = -1, **kwargs
) -> Tuple[torch.Tensor, list]:
filtered_tensors = [t for t in tensors if t is not None]
mask = [(1.0 if t is not None else 0.0) for t in tensors]
return torch.cat(filtered_tensors, dim=dim, **kwargs), mask
def concat_sequence_with_dummy(
tensors: List[torch.Tensor | None], seq_lens: List[int]
) -> torch.Tensor:
"""Concatenate a sequence of tensors. If a tensor is `None`, it will be replaced by a dummy tensor of zeros.
Args:
tensors (List[torch.Tensor | None]):
Tensors to concatenate. If a tensor is `None`, it will be replaced by a dummy tensor of zeros.
seq_lens (List[int]):
Expected sequence length of each tensor.
"""
assert len(tensors) == len(seq_lens)
    # infer batch size, feature dim, device and dtype from the first available tensor
    for t in tensors:
        if t is not None:
            b, d = t.shape[0], t.shape[2]
            device, dtype = t.device, t.dtype
            break
x = []
for t, seq_len in zip(tensors, seq_lens):
if t is None:
x.append(torch.zeros((b, seq_len, d), dtype=dtype, device=device))
else:
x.append(t)
return torch.cat(x, dim=1)
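# Usage sketch (illustrative shapes, not part of the module API): a missing modality is replaced
# by a zero block of the expected length, so the concatenated sequence always has the same layout
# regardless of which inputs are provided.
#
#     rgb_emb = torch.randn(4, 256, 768)    # (B, num_rgb_tokens, D)
#     depth_emb = None                      # depth modality missing
#     pc_emb = torch.randn(4, 128, 768)     # (B, num_pc_centers, D)
#     all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], [256, 256, 128])
#     # all_emb: (4, 640, 768); positions 256:512 are zeros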
def patchify(pixel_values, patch_size, num_channels, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Pixel values.
interpolate_pos_encoding (`bool`, *optional*, default `False`):
interpolation flag passed during the forward pass.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
"""
# sanity checks
if not interpolate_pos_encoding and (
pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
):
raise ValueError(
"Make sure the pixel values have a squared size that is divisible by the patch size"
)
if pixel_values.shape[1] != num_channels:
raise ValueError(
"Make sure the number of channels of the pixel values is equal to the one set in the configuration"
)
# patchify
batch_size = pixel_values.shape[0]
num_patches_h = pixel_values.shape[2] // patch_size
num_patches_w = pixel_values.shape[3] // patch_size
patchified_pixel_values = pixel_values.reshape(
batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
)
patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
)
return patchified_pixel_values
class CrossAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.num_heads = num_heads
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, context):
q = self.q(x)
q = einops.rearrange(q, "b t (h d) -> b h t d", h=self.num_heads)
kv = self.kv(context)
kv = einops.rearrange(kv, "b t (h d) -> b h t d", h=self.num_heads)
k, v = torch.chunk(kv, 2, dim=-1)
attn_drop = self.attn_drop if self.training else 0.0
x = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
x = einops.rearrange(x, "b h t d -> b t (h d)")
x = self.proj(x)
x = self.proj_drop(x)
return x
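# Shape sketch (hypothetical sizes): the query sequence attends to a context sequence of a
# different length, and the output keeps the query length. The decoder uses this to let each
# modality's queries read from the encoder's unmasked tokens.
#
#     attn = CrossAttention(dim=512, num_heads=8)
#     queries = torch.randn(2, 196, 512)
#     context = torch.randn(2, 64, 512)
#     out = attn(queries, context)   # (2, 196, 512)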
def unpatchify(patchified_pixel_values, patch_size, num_channels, original_image_size):
"""
Args:
patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
original_image_size (`Tuple[int, int]`, *optional*):
Original image size.
Returns:
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
Pixel values.
"""
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
# sanity check
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
raise ValueError(
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
)
# unpatchify
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_h,
num_patches_w,
patch_size,
patch_size,
num_channels,
)
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_h * patch_size,
num_patches_w * patch_size,
)
return pixel_values
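# Round-trip sketch (assumes a square 224x224 input with patch_size=14; these are illustrative
# values, the real sizes come from the config): patchify and unpatchify are exact inverses since
# they only reshape and permute.
#
#     x = torch.randn(2, 3, 224, 224)
#     patches = patchify(x, patch_size=14, num_channels=3)   # (2, 256, 588)
#     recon = unpatchify(patches, 14, 3, (224, 224))          # (2, 3, 224, 224)
#     assert torch.equal(recon, x)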
@numba.jit(nopython=True)
def get_mm_shuffle_indices(p, embedding_sz, unmask_sz=128):
b = p.shape[0]
n_modals = len(embedding_sz)
embedding_sz = np.array(embedding_sz)
indices = np.empty((b, embedding_sz.sum()), dtype=np.int64)
for i in numba.prange(b):
um_sz = np.round(p[i] * unmask_sz).astype(np.int64)
um_sz[-1] = unmask_sz - um_sz[:-1].sum()
m_sz = embedding_sz - um_sz
cm_um_sz = np.cumsum(um_sz)
cm_m_sz = np.cumsum(m_sz)
for j in range(n_modals):
shuffle_idx = np.argsort(np.random.random(embedding_sz[j])) + embedding_sz[:j].sum()
um = shuffle_idx[: um_sz[j]]
m = shuffle_idx[um_sz[j] :]
if j == 0:
indices[i, : cm_um_sz[j]] = um
indices[i, unmask_sz : cm_m_sz[j] + unmask_sz] = m
else:
indices[i, cm_um_sz[j - 1] : cm_um_sz[j]] = um
indices[i, cm_m_sz[j - 1] + unmask_sz : cm_m_sz[j] + unmask_sz] = m
return indices
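# Layout sketch (toy numbers): with per-modality token counts (4, 4, 2) and unmask_sz=4, each
# output row lists the 4 unmasked token positions first and the 6 masked ones after, all expressed
# as indices into the concatenated (RGB, depth, point-cloud) sequence, with roughly
# p[i] * unmask_sz unmasked tokens drawn from modality i.
#
#     p = np.array([[0.5, 0.25, 0.25]])
#     idx = get_mm_shuffle_indices(p, (4, 4, 2), unmask_sz=4)
#     # idx.shape == (1, 10); idx[0, :4] -> visible tokens, idx[0, 4:] -> masked tokens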
def prepare_shuffle_idx(
    has_rgb: bool,
    has_depth: bool,
    has_pc: bool,
    batch_size: int,
    unmask_sz: int,
    dirichlet: torch.distributions.Dirichlet,
    embedding_sz: Tuple[int, int, int],
    add_mask: bool = True,
    shuffle_idx: Optional[torch.Tensor] = None,
    device: Optional[torch.device | str] = "cuda",
):
    """Prepare shuffle indices for the input embeddings.
    Args:
        has_rgb (bool):
            Whether the RGB modality is present.
        has_depth (bool):
            Whether the depth modality is present.
        has_pc (bool):
            Whether the point cloud modality is present.
        batch_size (int):
            Batch size.
        unmask_sz (int):
            Number of tokens that stay unmasked when `add_mask` is True.
        dirichlet (torch.distributions.Dirichlet):
            Distribution used to sample the per-modality unmask proportions.
        embedding_sz (Tuple[int, int, int]):
            Number of tokens of each modality (RGB, depth, point cloud).
        add_mask (bool, optional):
            Whether to add a mask for masked autoencoding. Defaults to True.
        shuffle_idx (Optional[torch.Tensor], optional):
            Precomputed shuffle indices. If provided (and `add_mask` is True), they are reused and only
            the restore indices are recomputed.
        device (Optional[torch.device | str], optional):
            Device on which the returned index tensors are placed.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, int]:
            `(shuffle_idx, restore_idx, unmask_sz)`, where `shuffle_idx[:, :unmask_sz]` indexes the
            visible tokens, `restore_idx` undoes the shuffle, and `unmask_sz` is the (possibly
            recomputed) number of visible tokens.
    """
# provide at least one modality
if not any([has_rgb, has_depth, has_pc]):
raise ValueError("provide at least one modality")
b = batch_size
if add_mask:
if shuffle_idx is not None:
restore_idx = torch.argsort(shuffle_idx, 1)
else:
mask = [float(each) for each in [has_rgb, has_depth, has_pc]]
# multi-modal shuffle
if sum(mask) > 1:
p = dirichlet.sample((b,)).numpy()
p = p * np.array(mask)[None]
p = p / p.sum(-1, keepdims=True)
shuffle_idx = get_mm_shuffle_indices(p, embedding_sz, unmask_sz)
            # uni-modal shuffle: the degenerate case of the multi-modal shuffle with a
            # one-hot proportion for the single available modality
            else:
                p = np.repeat(np.array(mask)[None], b, axis=0)
                shuffle_idx = get_mm_shuffle_indices(p, embedding_sz, unmask_sz)
restore_idx = np.argsort(shuffle_idx, 1)
shuffle_idx = torch.tensor(shuffle_idx, device=device)
restore_idx = torch.tensor(restore_idx, device=device)
else:
# the missing modality is regarded as masked
unmask_parts, mask_parts = [], []
cumsum_emb_sz = np.cumsum(embedding_sz)
for i, has_modal in enumerate([has_rgb, has_depth, has_pc]):
indices = torch.arange(
cumsum_emb_sz[i - 1] if i > 0 else 0,
cumsum_emb_sz[i],
device=device,
)
if has_modal:
unmask_parts.append(indices)
else:
mask_parts.append(indices)
shuffle_idx = torch.cat(unmask_parts + mask_parts, dim=0)[None].repeat(b, 1)
restore_idx = torch.argsort(shuffle_idx, 1)
unmask_sz = sum([len(part) for part in unmask_parts])
return shuffle_idx, restore_idx, unmask_sz
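# Usage sketch (hypothetical sizes): shuffle_idx[:, :unmask_sz] selects the visible tokens fed to
# the encoder, and restore_idx later undoes the shuffle in the decoder.
#
#     dirichlet = torch.distributions.Dirichlet(torch.full((3,), 1.0))
#     shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
#         has_rgb=True, has_depth=True, has_pc=False,
#         batch_size=4, unmask_sz=128, dirichlet=dirichlet,
#         embedding_sz=(256, 256, 128), add_mask=True, device="cpu",
#     )
#     # shuffle_idx: (4, 640). With all_emb of shape (4, 640, D):
#     # visible = torch.gather(all_emb, 1, shuffle_idx[:, :unmask_sz, None].repeat(1, 1, D))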
@numba.jit(nopython=True)
def get_shuffle_indices(embedding_sz):
shuffle_idx = np.argsort(np.random.random(embedding_sz))
return shuffle_idx
def torch_int(x):
    return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
def fps_and_knn(x: torch.Tensor, num_centers: int, num_knn: int):
dtype = x.dtype
x = x.to(torch.float32)
centers, _ = torch3d_ops.sample_farthest_points(x, K=num_centers) # (b, num_centers, 3)
knn_points = torch3d_ops.knn_points(
centers, x, K=num_knn, return_nn=True
).knn # (b, num_centers, knn, 3)
return centers.to(dtype), knn_points.to(dtype)
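# Sketch (requires pytorch3d; sizes are illustrative): farthest-point sampling picks num_centers
# representative points, then the num_knn nearest neighbours of each center are grouped around it.
#
#     pc = torch.rand(2, 8192, 3)                       # (B, N, 3)
#     centers, knn = fps_and_knn(pc, num_centers=128, num_knn=32)
#     # centers: (2, 128, 3), knn: (2, 128, 32, 3)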
def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
"""
Create 2D sin/cos positional embeddings.
Args:
embed_dim (`int`):
Embedding dimension.
grid_size (`int`):
The grid height and width.
add_cls_token (`bool`, *optional*, defaults to `False`):
Whether or not to add a classification (CLS) token.
Returns:
(`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
position embeddings (with or without classification token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if add_cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be even")
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
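# Sketch (illustrative sizes): one fixed sin/cos vector per grid position, e.g. for a 16x16 patch
# grid and a 768-dim model.
#
#     pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)                          # (256, 768)
#     pos_cls = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16, add_cls_token=True)  # (257, 768)
#     # returned as numpy arrays; callers convert with torch.tensor(pos, dtype=torch.float32)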
@dataclass
class EncoderModelOutput(ModelOutput):
embedding: torch.Tensor = None
pc_centers: torch.Tensor = None
pc_knn: torch.Tensor = None
shuffle_idx: torch.Tensor = None
restore_idx: torch.Tensor = None
last_hidden_states: Optional[torch.Tensor] = None
add_mask: bool = None
hidden_states: Optional[torch.Tensor] = None
attentions: Optional[Tuple[torch.Tensor]] = None
unmask_sz: int = None
@dataclass
class DecoderInput(ModelOutput):
rgb_embedding: torch.Tensor = None
depth_embedding: torch.Tensor = None
pc_embedding: torch.Tensor = None
unmasked_emb: torch.Tensor = None
shuffle_idx: torch.Tensor = None
pc_centers: torch.Tensor = None
pc_knn: torch.Tensor = None
add_mask: bool = None
unmask_sz: int = None
class SharedMlp(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim),
nn.GELU(approximate="tanh"),
)
def forward(self, x: torch.Tensor):
return self.net(x)
class MaxPool(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.dim = dim
def forward(self, x: torch.Tensor):
return x.max(self.dim)[0]
class PointGroupEmbedding(nn.Module):
def __init__(self, point_dim: int, d_model: int):
super().__init__()
self.net = nn.Sequential(
SharedMlp(point_dim, 64),
SharedMlp(64, 128),
SharedMlp(128, 256),
MaxPool(-2),
nn.Linear(256, d_model),
)
def forward(self, x: torch.Tensor):
return self.net(x)
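# Shape sketch (hypothetical sizes): a small PointNet-style block, i.e. a shared per-point MLP,
# a max-pool over each local group, then a projection to the model width.
#
#     pge = PointGroupEmbedding(point_dim=3, d_model=768)
#     groups = torch.randn(2, 128, 32, 3)   # (B, num_centers, num_knn, 3)
#     tokens = pge(groups)                  # (2, 128, 768)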
class Conv2dPatchify(nn.Module):
def __init__(
self,
patch_size: int = 14,
hidden_size: int = 768,
num_channels: int = 3,
):
super().__init__()
self.num_channels = num_channels
self.patchify = nn.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
num_channels = pixel_values.shape[-3]
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
f" Expected {self.num_channels} but got {num_channels}."
)
embeddings = self.patchify(pixel_values).flatten(2).transpose(1, 2)
return embeddings
class PatchEmbeddings(nn.Module):
def __init__(
self,
image_size: int = 224,
patch_size: int = 14,
hidden_size: int = 768,
num_channels: int = 3,
dropout: float = 0.0,
):
        super().__init__()
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.embeddings = Conv2dPatchify(patch_size, hidden_size, num_channels)
# Use learnable positional embeddings initialized at sin-cos
pos_emb = get_2d_sincos_pos_embed(hidden_size, image_size // patch_size)
pos_emb = torch.tensor(pos_emb, dtype=torch.float32)[None]
self.position_embeddings = nn.Parameter(pos_emb)
self.dropout = nn.Dropout(dropout)
def interpolate_pos_encoding(
self, embeddings: torch.Tensor, height: int, width: int
) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""
num_patches = embeddings.shape[1]
num_positions = self.position_embeddings.shape[1]
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings
        # unlike DINOv2 there is no CLS token here, so all position embeddings belong to patches
        patch_pos_embed = self.position_embeddings
dim = embeddings.shape[-1]
new_height = height // self.patch_size
new_width = width // self.patch_size
sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
target_dtype = patch_pos_embed.dtype
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(torch.float32),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
).to(dtype=target_dtype)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
    def forward(self, pixel_values: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
if pixel_values is None:
return None
batch_size, _, height, width = pixel_values.shape
target_dtype = self.embeddings.patchify.weight.dtype
embeddings = self.embeddings(pixel_values.to(dtype=target_dtype))
# add positional encoding to each token
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
embeddings = self.dropout(embeddings)
return embeddings
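# Usage sketch (assumes image_size=224, patch_size=14, hidden_size=768; illustrative values only):
# images are split into patches by a strided convolution, and the learnable position embeddings
# (initialized with fixed sin/cos values) are added.
#
#     emb = PatchEmbeddings(image_size=224, patch_size=14, hidden_size=768, num_channels=3)
#     x = torch.randn(2, 3, 224, 224)
#     tokens = emb(x)    # (2, 256, 768)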
class EmbodiedMAERGBEmbeddings(PatchEmbeddings):
def __init__(self, config: EmbodiedMAEConfig):
super().__init__(
image_size=config.image_size,
patch_size=config.patch_size,
hidden_size=config.hidden_size,
num_channels=3,
dropout=0.0,
)
class EmbodiedMAEDepthEmbeddings(PatchEmbeddings):
def __init__(self, config: EmbodiedMAEConfig):
super().__init__(
image_size=config.image_size,
patch_size=config.patch_size,
hidden_size=config.hidden_size,
num_channels=1,
dropout=0.0,
)
class EmbodiedMAEPointCloudEmbeddings(nn.Module):
def __init__(self, config: EmbodiedMAEConfig):
super().__init__()
self.num_centers, self.num_knn = config.num_pc_centers, config.num_pc_knn
self.knn_embeddings = PointGroupEmbedding(3, config.hidden_size)
self.center_embeddings = nn.Sequential(
nn.Linear(3, config.hidden_size),
nn.GELU(approximate="tanh"),
nn.Linear(config.hidden_size, config.hidden_size),
)
    def forward(
        self, point_cloud: Optional[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
if point_cloud is None:
return None, None, None
centers, knn_points = fps_and_knn(
point_cloud, num_centers=self.num_centers, num_knn=self.num_knn
)
normed_knn_points = knn_points - centers.unsqueeze(-2)
center_emb = self.center_embeddings(centers)
knn_emb = self.knn_embeddings(normed_knn_points)
return center_emb + knn_emb, centers, normed_knn_points
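# Sketch (hypothetical config values; requires pytorch3d): each point-cloud token combines a
# center-coordinate embedding with a PointNet embedding of the center-relative k-NN group.
#
#     cfg = EmbodiedMAEConfig(hidden_size=768, num_pc_centers=128, num_pc_knn=32)
#     pc_embed = EmbodiedMAEPointCloudEmbeddings(cfg)
#     tokens, centers, knn = pc_embed(torch.rand(2, 8192, 3))
#     # tokens: (2, 128, 768), centers: (2, 128, 3), knn (center-relative): (2, 128, 32, 3)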
# class EmbodiedMAEModel(nn.Module):
# def __init__(self, config: EmbodiedMAEConfig):
# super().__init__()
# self.config = config
# self.dirichlet = torch.distributions.Dirichlet(torch.full((3,), config.dirichlet_alpha))
# # self.dirichlets = [
# # torch.distributions.Dirichlet(torch.full((i,), config.dirichlet_alpha))
# # for i in range(1, 3)
# # ]
# self.rgb_embeddings = EmbodiedMAERGBEmbeddings(config)
# self.depth_embeddings = EmbodiedMAEDepthEmbeddings(config)
# self.pc_embeddings = EmbodiedMAEPointCloudEmbeddings(config)
# # backbone: Dinov2Model = Dinov2Model.from_pretrained(config.backbone)
# self.encoder = Dinov2Encoder(config)
# # self.encoder.load_state_dict(backbone.encoder.state_dict())
# self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# num_patches = (config.image_size // config.patch_size) ** 2
# self.embedding_sz = (
# num_patches,
# num_patches,
# config.num_pc_centers,
# ) # token size for each modality
# self.unmask_sz = config.unmask_sz # number of unmasked tokens
# def get_input_embeddings(
# self,
# rgb: Optional[torch.Tensor],
# depth: Optional[torch.Tensor],
# pc: Optional[torch.Tensor],
# add_mask: bool = True,
# unmask_sz: Optional[int] = None,
# forward_pc: bool = True,
# shuffle_idx: Optional[torch.Tensor] = None,
# ):
# # provide at least one modality
# assert any([rgb is not None, depth is not None, pc is not None])
# # embeddings
# rgb_emb = self.rgb_embeddings(rgb)
# depth_emb = self.depth_embeddings(depth)
# pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
# if not forward_pc:
# pc = None
# pc_emb = None
# # concat embeddings
# all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], self.embedding_sz)
# # prepare shuffle indices
# shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
# has_rgb=rgb is not None,
# has_depth=depth is not None,
# has_pc=pc is not None,
# batch_size=all_emb.shape[0],
# unmask_sz=self.unmask_sz if unmask_sz is None else unmask_sz,
# dirichlet=self.dirichlet,
# embedding_sz=self.embedding_sz,
# add_mask=add_mask,
# shuffle_idx=shuffle_idx,
# device=all_emb.device,
# )
# # get unmasked embeddings
# unmasked_emb = torch.gather(
# all_emb, 1, shuffle_idx[:, :unmask_sz, None].repeat(1, 1, all_emb.shape[-1])
# )
# return EncoderModelOutput(
# embedding=unmasked_emb,
# pc_centers=pc_centers,
# pc_knn=pc_knn,
# shuffle_idx=shuffle_idx,
# restore_idx=restore_idx,
# add_mask=add_mask,
# unmask_sz=unmask_sz,
# )
# # def get_input_embeddings_with_manual_mask(
# # self,
# # rgb: Optional[torch.Tensor],
# # depth: Optional[torch.Tensor],
# # pc: Optional[torch.Tensor],
# # shuffle_idx: torch.Tensor,
# # unmask_sz: int,
# # forward_pc: bool = True,
# # ):
# # # provide at least one modality
# # assert any([rgb is not None, depth is not None, pc is not None])
# # # embeddings
# # rgb_emb = self.rgb_embeddings(rgb)
# # depth_emb = self.depth_embeddings(depth)
# # pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
# # if not forward_pc:
# # pc = None
# # pc_emb = None
# # # concat embeddings
# # all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], self.embedding_sz)
# # shuffle_idx = shuffle_idx.to(all_emb.device)
# # restore_idx = torch.argsort(shuffle_idx, 1)
# # unmasked_emb = torch.gather(
# # all_emb, 1, shuffle_idx[:, :unmask_sz, None].repeat(1, 1, all_emb.shape[-1])
# # )
# # return EncoderModelOutput(
# # embedding=unmasked_emb,
# # pc_centers=pc_centers,
# # pc_knn=pc_knn,
# # shuffle_idx=shuffle_idx,
# # restore_idx=restore_idx,
# # add_mask=None,
# # unmask_sz=unmask_sz,
# # )
# def get_last_hidden_states(
# self,
# embedding_output: EncoderModelOutput,
# output_attentions: bool = False,
# output_hidden_states: bool = False,
# ):
# embedding = embedding_output.embedding
# encoder_outputs = self.encoder(
# embedding,
# output_attentions=output_attentions,
# output_hidden_states=output_hidden_states,
# )
# sequence_output = encoder_outputs[0]
# sequence_output = self.layernorm(sequence_output)
# embedding_output.last_hidden_states = sequence_output
# embedding_output.hidden_states = encoder_outputs.hidden_states
# embedding_output.attentions = encoder_outputs.attentions
# return embedding_output
# def get_decoder_input(self, encoder_output: EncoderModelOutput):
# unmasked_emb = encoder_output.last_hidden_states
# unmask_sz = encoder_output.unmask_sz
# # if encoder_output.add_mask:
# masked_emb = torch.zeros(
# (
# unmasked_emb.shape[0],
# sum(self.embedding_sz) - unmask_sz,
# unmasked_emb.shape[-1],
# ),
# device=unmasked_emb.device,
# dtype=unmasked_emb.dtype,
# )
# all_emb = torch.cat([unmasked_emb, masked_emb], dim=1)
# all_emb = torch.gather(
# all_emb,
# 1,
# encoder_output.restore_idx.unsqueeze(-1).repeat(1, 1, all_emb.shape[-1]),
# )
# # else:
# # all_emb = unmasked_emb
# rgb_emb, depth_emb, pc_emb = torch.split(all_emb, self.embedding_sz, dim=1)
# return DecoderInput(
# rgb_embedding=rgb_emb,
# depth_embedding=depth_emb,
# pc_embedding=pc_emb,
# unmasked_emb=unmasked_emb,
# shuffle_idx=encoder_output.shuffle_idx,
# pc_centers=encoder_output.pc_centers,
# pc_knn=encoder_output.pc_knn,
# add_mask=encoder_output.add_mask,
# unmask_sz=unmask_sz,
# )
# def forward(
# self,
# rgb: Optional[torch.Tensor],
# depth: Optional[torch.Tensor],
# pc: Optional[torch.Tensor],
# add_mask: bool = True,
# unmask_sz: Optional[int] = None,
# output_attentions: bool = False,
# output_hidden_states: bool = False,
# forward_pc: bool = True,
# ):
# embedding_output = self.get_input_embeddings(
# rgb, depth, pc, add_mask, unmask_sz, forward_pc
# )
# return self.get_last_hidden_states(
# embedding_output, output_attentions, output_hidden_states
# )
class EmbodiedMAEDecoder(nn.Module):
def __init__(self, config: EmbodiedMAEConfig):
super().__init__()
image_size = config.image_size
patch_size = config.patch_size
self.config = config
pos_emb = get_2d_sincos_pos_embed(config.decoder_hidden_size, image_size // patch_size)
        self.rgb_pos_embed = nn.Parameter(torch.tensor(pos_emb, dtype=torch.float32)[None])
        self.depth_pos_embed = nn.Parameter(torch.tensor(pos_emb, dtype=torch.float32)[None])
self.pc_pos_embed = nn.Sequential(
nn.Linear(3, config.decoder_hidden_size),
nn.GELU(approximate="tanh"),
nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size),
)
num_patches = (config.image_size // config.patch_size) ** 2
self.embedding_sz = (num_patches, num_patches, config.num_pc_centers)
self.unmask_sz = config.unmask_sz
self.context_pos_emb = nn.Parameter(
torch.randn(sum(self.embedding_sz), config.decoder_hidden_size)
)
nn.init.trunc_normal_(self.context_pos_emb, std=config.initializer_range)
self.rgb_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
self.depth_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
self.pc_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
self.rgb_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.depth_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.pc_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.context_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
self.context_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.rgb_cross_attn = CrossAttention(config.decoder_hidden_size)
self.depth_cross_attn = CrossAttention(config.decoder_hidden_size)
self.pc_cross_attn = CrossAttention(config.decoder_hidden_size)
dec_config = deepcopy(config)
dec_config.hidden_size = config.decoder_hidden_size
dec_config.num_hidden_layers = config.decoder_num_hidden_layers
dec_config.num_attention_heads = config.decoder_num_attention_heads
self.rgb_layer = nn.ModuleList(
[Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
)
self.depth_layer = nn.ModuleList(
[Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
)
self.pc_layer = nn.ModuleList(
[Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
)
self.rgb_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.depth_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.pc_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
self.rgb_out_proj = nn.Linear(config.decoder_hidden_size, config.patch_size**2 * 3)
self.depth_out_proj = nn.Linear(config.decoder_hidden_size, config.patch_size**2)
self.pc_out_proj = nn.Linear(config.decoder_hidden_size, config.num_pc_knn * 3)
self.norm_pix_loss = config.norm_pix_loss
def get_decoder_input(self, encoder_output: EncoderModelOutput):
"""Convert the encoder output to decoder input."""
unmasked_emb = encoder_output.last_hidden_states
unmask_sz = encoder_output.unmask_sz
masked_emb = torch.zeros(
(
unmasked_emb.shape[0],
sum(self.embedding_sz) - unmask_sz,
unmasked_emb.shape[-1],
),
device=unmasked_emb.device,
dtype=unmasked_emb.dtype,
)
all_emb = torch.cat([unmasked_emb, masked_emb], dim=1)
all_emb = torch.gather(
all_emb,
1,
encoder_output.restore_idx.unsqueeze(-1).repeat(1, 1, all_emb.shape[-1]),
)
rgb_emb, depth_emb, pc_emb = torch.split(all_emb, self.embedding_sz, dim=1)
return DecoderInput(
rgb_embedding=rgb_emb,
depth_embedding=depth_emb,
pc_embedding=pc_emb,
unmasked_emb=unmasked_emb,
shuffle_idx=encoder_output.shuffle_idx,
pc_centers=encoder_output.pc_centers,
pc_knn=encoder_output.pc_knn,
add_mask=encoder_output.add_mask,
unmask_sz=unmask_sz,
)
def forward(self, decoder_input: DecoderInput):
unmask_sz = decoder_input.unmask_sz if decoder_input.unmask_sz else self.unmask_sz
rgb_query = self.rgb_query_proj(decoder_input.rgb_embedding)
depth_query = self.depth_query_proj(decoder_input.depth_embedding)
pc_query = self.pc_query_proj(decoder_input.pc_embedding)
rgb_query = self.rgb_query_norm(rgb_query + self.rgb_pos_embed)
depth_query = self.depth_query_norm(depth_query + self.depth_pos_embed)
if decoder_input.pc_centers is not None:
pc_pos_embed = self.pc_pos_embed(decoder_input.pc_centers)
else:
pc_pos_embed = 0
pc_query = self.pc_query_norm(pc_query + pc_pos_embed)
context = self.context_proj(decoder_input.unmasked_emb)
shuffle_idx = decoder_input.shuffle_idx[:, :unmask_sz]
context_pos_emb = self.context_pos_emb[shuffle_idx]
context = self.context_norm(context + context_pos_emb)
rgb_emb = self.rgb_cross_attn(rgb_query, context)
depth_emb = self.depth_cross_attn(depth_query, context)
pc_emb = self.pc_cross_attn(pc_query, context)
for layers in self.rgb_layer:
rgb_emb = layers(rgb_emb)[0]
for layers in self.depth_layer:
depth_emb = layers(depth_emb)[0]
for layers in self.pc_layer:
pc_emb = layers(pc_emb)[0]
rgb_emb = self.rgb_out_norm(rgb_emb)
depth_emb = self.depth_out_norm(depth_emb)
pc_emb = self.pc_out_norm(pc_emb)
rgb_out = self.rgb_out_proj(rgb_emb)
depth_out = self.depth_out_proj(depth_emb)
pc_out = self.pc_out_proj(pc_emb)
return rgb_out, depth_out, pc_out
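    # Flow sketch (illustrative sizes, e.g. 256 RGB tokens, 256 depth tokens, 128 point-cloud
    # centers): get_decoder_input fills masked positions with zeros, un-shuffles the sequence with
    # restore_idx, and splits it back into per-modality streams; forward then decodes each stream
    # with cross-attention over the unmasked tokens.
    #
    #     decoder_input = decoder.get_decoder_input(encoder_output)
    #     rgb_out, depth_out, pc_out = decoder(decoder_input)
    #     # rgb_out:   (B, 256, patch_size**2 * 3)
    #     # depth_out: (B, 256, patch_size**2)
    #     # pc_out:    (B, 128, num_pc_knn * 3)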
def get_loss(self, decoder_input: DecoderInput, rgb, depth, pc):
unmask_sz = decoder_input.unmask_sz
b = rgb.shape[0]
rgb_out, depth_out, pc_out = self(decoder_input)
target_rgb, target_depth = (
patchify(rgb, self.config.patch_size, 3),
patchify(depth, self.config.patch_size, 1),
)
        target_pc = decoder_input.pc_knn * 10.0  # scale point coordinates by 10 (meters -> decimeters) to balance the loss
if self.norm_pix_loss:
rgb_mean, rgb_std = (
target_rgb.mean(-1, keepdim=True),
target_rgb.std(-1, keepdim=True),
)
depth_mean, depth_std = (
target_depth.mean(-1, keepdim=True),
target_depth.std(-1, keepdim=True),
)
else:
rgb_mean, rgb_std = 0.0, 1.0
depth_mean, depth_std = 0.0, 1.0
target_rgb = (target_rgb - rgb_mean) / (rgb_std + 1e-8)
target_depth = (target_depth - depth_mean) / (depth_std + 1e-8)
mask = torch.ones((b, sum(self.embedding_sz)), device=rgb.device)
mask[
torch.arange(b, device=rgb.device)[:, None],
decoder_input.shuffle_idx[:, :unmask_sz],
] = 0
rgb_mask, depth_mask, pc_mask = torch.split(mask, self.embedding_sz, dim=1)
rgb_loss = ((rgb_out - target_rgb).pow(2).mean(-1) * rgb_mask).sum() / rgb_mask.sum()
depth_loss = (
(depth_out - target_depth).abs().mean(-1) * depth_mask
).sum() / depth_mask.sum()
pred_pc = einops.rearrange(pc_out[pc_mask.bool()], "b (k n) -> b k n", n=3)
target_pc = target_pc[pc_mask.bool()]
pc_loss = chamfer_distance(pred_pc.float(), target_pc.float(), norm=1)[0]
return rgb_loss, depth_loss, pc_loss
@torch.no_grad()
def visualize(
self, decoder_input: DecoderInput, rgb: torch.Tensor, depth: torch.Tensor, pc: torch.Tensor
):
"""Visualize the predictions of the decoder.
Args:
decoder_input (DecoderInput):
`decoder_input` from `get_decoder_input`.
rgb (torch.Tensor):
RGB image with shape (B, 3, H, W) in [-1, 1] range.
depth (torch.Tensor):
Depth map with shape (B, 1, H, W) in [0, inf] range. Unit is meters.
pc (torch.Tensor):
Point cloud with shape (B, N, 3), where N=8192 is the number of points. Unit is meters.
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                RGB and depth visualizations of shape (B, 3*H, W, 3) and (B, 3*H, W, 1) in [0, 1],
                stacking the original, masked, and reconstructed views along the height axis, and
                the reconstructed point cloud of shape (B, num_pc_centers, num_pc_knn, 3) in meters.
        """
rgb_out, depth_out, pc_out = self(decoder_input)
pc_centers = decoder_input.pc_centers
pc_out = einops.rearrange(pc_out, "... (k n) -> ... k n", n=3)
plt_pc = pc_out / 10.0 + pc_centers.unsqueeze(-2)
b = rgb_out.shape[0]
unmask_sz = decoder_input.unmask_sz
target_rgb, target_depth = (
patchify(rgb, self.config.patch_size, 3),
patchify(depth, self.config.patch_size, 1),
)
if self.norm_pix_loss:
rgb_mean, rgb_std = (
target_rgb.mean(-1, keepdim=True),
target_rgb.std(-1, keepdim=True),
)
depth_mean, depth_std = (
target_depth.mean(-1, keepdim=True),
target_depth.std(-1, keepdim=True),
)
else:
rgb_mean, rgb_std = 0.0, 1.0
depth_mean, depth_std = 0.0, 1.0
pred_rgb = rgb_out * (rgb_std + 1e-8) + rgb_mean
pred_depth = depth_out * (depth_std + 1e-8) + depth_mean
mask = torch.ones((b, sum(self.embedding_sz)), device=rgb.device)
if decoder_input.add_mask:
mask[
torch.arange(b, device=rgb.device)[:, None],
decoder_input.shuffle_idx[:, :unmask_sz],
] = 0
rgb_mask, depth_mask, _ = torch.split(mask, self.embedding_sz, dim=1)
masked_rgb = torch.ones_like(target_rgb) - 2.0
masked_rgb[~rgb_mask.bool()] = target_rgb[~rgb_mask.bool()].to(masked_rgb.dtype)
masked_rgb = unpatchify(
masked_rgb,
self.config.patch_size,
3,
(self.config.image_size, self.config.image_size),
)
pred_rgb[~rgb_mask.bool()] = target_rgb[~rgb_mask.bool()].to(pred_rgb.dtype)
pred_rgb = unpatchify(
pred_rgb,
self.config.patch_size,
3,
(self.config.image_size, self.config.image_size),
)
masked_depth = torch.zeros_like(pred_depth)
masked_depth[~depth_mask.bool()] = target_depth[~depth_mask.bool()].to(masked_depth.dtype)
masked_depth = unpatchify(
masked_depth,
self.config.patch_size,
1,
(self.config.image_size, self.config.image_size),
)
pred_depth[~depth_mask.bool()] = target_depth[~depth_mask.bool()].to(pred_depth.dtype)
pred_depth = unpatchify(
pred_depth,
self.config.patch_size,
1,
(self.config.image_size, self.config.image_size),
)
plt_rgb = (
torch.cat([rgb.float(), masked_rgb.float(), pred_rgb.float()], 2) * 0.5 + 0.5
).clip(0, 1)
plt_depth = (
torch.cat([depth.float(), masked_depth.float(), pred_depth.float()], 2) / 2.0
).clip(0, 1)
return (
plt_rgb.permute(0, 2, 3, 1).cpu(),
plt_depth.permute(0, 2, 3, 1).cpu(),
plt_pc.cpu(),
)