ZibinDong committed
Commit 4b587c2 · verified · 1 Parent(s): 654502f

Upload model
README.md ADDED
@@ -0,0 +1,199 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for EmbodiedMAE
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
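The original card leaves this section as a placeholder. Below is a minimal, hypothetical sketch of how the checkpoint could be loaded through the custom code files in this repository (`configuration_embodiedmae.py`, `modeling_embodiedmae.py`); the repository id is a placeholder, and the remote code additionally requires `einops`, `numba`, and `pytorch3d`.

```python
import torch
from transformers import AutoModel

# Placeholder repo id -- replace with the actual Hub repository name.
model = AutoModel.from_pretrained("<namespace>/<repo-name>", trust_remote_code=True)
model.eval()

# Dummy inputs with the shapes documented in modular_embodiedmae.py:
#   rgb:   (B, 3, 224, 224), values in [-1, 1]
#   depth: (B, 1, 224, 224), values in [0, 2] (meters)
#   pc:    (B, N, 3) point cloud (meters)
rgb = torch.rand(1, 3, 224, 224) * 2 - 1
depth = torch.rand(1, 1, 224, 224) * 2
pc = torch.rand(1, 8192, 3)

with torch.no_grad():
    # add_mask=False keeps every token of the provided modalities visible.
    out = model(rgb, depth, pc, add_mask=False)

print(out.last_hidden_states.shape)  # (1, 588, 768) for this base-size config
```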
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,36 @@
1
+ {
2
+ "apply_layernorm": true,
3
+ "architectures": [
4
+ "EmbodiedMAEModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_embodiedmae.EmbodiedMAEConfig",
9
+ "AutoModel": "modeling_embodiedmae.EmbodiedMAEModel"
10
+ },
11
+ "decoder_hidden_size": 512,
12
+ "decoder_num_attention_heads": 8,
13
+ "decoder_num_hidden_layers": 4,
14
+ "dirichlet_alpha": 1.0,
15
+ "drop_path_rate": 0.0,
16
+ "hidden_act": "gelu",
17
+ "hidden_dropout_prob": 0.0,
18
+ "hidden_size": 768,
19
+ "image_size": 224,
20
+ "initializer_range": 0.02,
21
+ "layer_norm_eps": 1e-06,
22
+ "layerscale_value": 1.0,
23
+ "mlp_ratio": 4,
24
+ "model_type": "EmbodiedMAE",
25
+ "norm_pix_loss": false,
26
+ "num_attention_heads": 12,
27
+ "num_hidden_layers": 12,
28
+ "num_pc_centers": 196,
29
+ "num_pc_knn": 64,
30
+ "patch_size": 16,
31
+ "qkv_bias": true,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": "4.48.0",
34
+ "unmask_sz": 98,
35
+ "use_swiglu_ffn": false
36
+ }
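For orientation (derived from the values above and the token layout in `modeling_embodiedmae.py`): each image modality contributes (224 / 16)² = 196 patch tokens and the point cloud contributes `num_pc_centers` = 196 center tokens, i.e. 588 tokens in total; with `unmask_sz` = 98, roughly 98 / 588 ≈ 17% of the tokens remain visible to the encoder when random masking is enabled.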
configuration_embodiedmae.py ADDED
@@ -0,0 +1,97 @@
1
+ from transformers import PretrainedConfig
2
+ from transformers.models.dinov2.configuration_dinov2 import Dinov2Config
3
+
4
+
5
+ class EmbodiedMAEConfig(PretrainedConfig):
6
+ model_type = "EmbodiedMAE"
7
+
8
+ def __init__(
9
+ self,
10
+ hidden_size: int = 768,
11
+ num_hidden_layers: int = 12,
12
+ num_attention_heads: int = 12,
13
+ mlp_ratio: int = 4,
14
+ hidden_dropout_prob: float = 0.0,
15
+ attention_probs_dropout_prob: float = 0.0,
16
+ initializer_range: float = 0.02,
17
+ qkv_bias: bool = True,
18
+ apply_layernorm: bool = True,
19
+ attn_implementation: str = "eager",
20
+ layerscale_value: float = 1.0,
21
+ drop_path_rate: float = 0.0,
22
+ layer_norm_eps: float = 1e-6,
23
+ hidden_act: str = "gelu",
24
+ use_swiglu_ffn: bool = False,
25
+ image_size: int = 224,
26
+ patch_size: int = 16,
27
+ num_pc_centers: int = 196,
28
+ num_pc_knn: int = 64,
29
+ dirichlet_alpha: float = 1.0,
30
+ unmask_sz: int = 98,
31
+ decoder_hidden_size: int = 512,
32
+ decoder_num_hidden_layers: int = 4,
33
+ decoder_num_attention_heads: int = 8,
34
+ norm_pix_loss: bool = False,
35
+ **kwargs,
36
+ ):
37
+ self.hidden_size = hidden_size
38
+ self.num_hidden_layers = num_hidden_layers
39
+ self.num_attention_heads = num_attention_heads
40
+ self.mlp_ratio = mlp_ratio
41
+ self.hidden_dropout_prob = hidden_dropout_prob
42
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
43
+ self.initializer_range = initializer_range
44
+
45
+ self.image_size = image_size
46
+ self.patch_size = patch_size
47
+ self.qkv_bias = qkv_bias
48
+ self.apply_layernorm = apply_layernorm
49
+ self.num_pc_centers = num_pc_centers
50
+ self.num_pc_knn = num_pc_knn
51
+ self.dirichlet_alpha = dirichlet_alpha
52
+ self.unmask_sz = unmask_sz
53
+
54
+ self._attn_implementation = attn_implementation
55
+ self.layerscale_value = layerscale_value
56
+ self.drop_path_rate = drop_path_rate
57
+ self.layer_norm_eps = layer_norm_eps
58
+ self.hidden_act = hidden_act
59
+ self.use_swiglu_ffn = use_swiglu_ffn
60
+
61
+ self.decoder_hidden_size = decoder_hidden_size
62
+ self.decoder_num_hidden_layers = decoder_num_hidden_layers
63
+ self.decoder_num_attention_heads = decoder_num_attention_heads
64
+
65
+ self.norm_pix_loss = norm_pix_loss
66
+
67
+ super().__init__(**kwargs)
68
+
69
+
70
+ BACKBONE_KWARGS = {
71
+ "hidden_size",
72
+ "num_hidden_layers",
73
+ "num_attention_heads",
74
+ "mlp_ratio",
75
+ "hidden_dropout_prob",
76
+ "attention_probs_dropout_prob",
77
+ "initializer_range",
78
+ "qkv_bias",
79
+ "apply_layernorm",
80
+ "attn_implementation",
81
+ "layerscale_value",
82
+ "drop_path_rate",
83
+ "layer_norm_eps",
84
+ "hidden_act",
85
+ "use_swiglu_ffn",
86
+ }
87
+
88
+
89
+ # get the config for a given DINOv2 backbone size (e.g. "base", "large", "giant")
90
+ def get_embodied_mae_config(size: str = "base") -> EmbodiedMAEConfig:
91
+ backbone_config = Dinov2Config.from_pretrained(f"facebook/dinov2-{size}")
92
+ kwargs = {k: v for k, v in backbone_config.to_dict().items() if k in BACKBONE_KWARGS}
93
+ norm_pix_loss = size == "giant"
94
+ return EmbodiedMAEConfig(**kwargs, norm_pix_loss=norm_pix_loss)
95
+
96
+
97
+ __all__ = ["EmbodiedMAEConfig", "get_embodied_mae_config"]
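A minimal usage sketch, assuming this file is importable as `configuration_embodiedmae`; note that `get_embodied_mae_config` downloads the matching `facebook/dinov2-*` config, so it needs network (or cached) access.

```python
from configuration_embodiedmae import EmbodiedMAEConfig, get_embodied_mae_config

# Backbone hyperparameters mirrored from facebook/dinov2-base.
config = get_embodied_mae_config("base")
print(config.hidden_size, config.num_hidden_layers)  # 768 12

# Or construct a config directly with explicit overrides.
custom = EmbodiedMAEConfig(decoder_num_hidden_layers=8, unmask_sz=128)
```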
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5de63a523109b61f8b10b4dee84f36bf06e3427e2359081d891ed8f2f9603749
3
+ size 348014176
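As a rough sanity check (not stated in the upload itself): at `torch_dtype: float32`, 348,014,176 bytes / 4 bytes per parameter ≈ 87M parameters, which is consistent with the 12-layer, 768-dim encoder configured in `config.json` plus the RGB, depth, and point-cloud embedding modules (the safetensors header adds only a small overhead).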
modeling_embodiedmae.py ADDED
@@ -0,0 +1,193 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers import PreTrainedModel
6
+ from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder
7
+
8
+ from .configuration_embodiedmae import EmbodiedMAEConfig
9
+ from .modular_embodiedmae import (
10
+ EmbodiedMAEDecoder,
11
+ EmbodiedMAEDepthEmbeddings,
12
+ EmbodiedMAEPointCloudEmbeddings,
13
+ EmbodiedMAERGBEmbeddings,
14
+ EncoderModelOutput,
15
+ concat_sequence_with_dummy,
16
+ prepare_shuffle_idx,
17
+ )
18
+
19
+
20
+ class EmbodiedMAEModel(PreTrainedModel):
21
+ config_class = EmbodiedMAEConfig
22
+
23
+ def __init__(self, config: EmbodiedMAEConfig):
24
+ super().__init__(config)
25
+ self.config = config
26
+
27
+ self.dirichlet = torch.distributions.Dirichlet(torch.full((3,), config.dirichlet_alpha))
28
+
29
+ self.rgb_embeddings = EmbodiedMAERGBEmbeddings(config)
30
+ self.depth_embeddings = EmbodiedMAEDepthEmbeddings(config)
31
+ self.pc_embeddings = EmbodiedMAEPointCloudEmbeddings(config)
32
+
33
+ self.encoder = Dinov2Encoder(config)
34
+
35
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
36
+
37
+ num_patches = (config.image_size // config.patch_size) ** 2
38
+ self.embedding_sz = (
39
+ num_patches,
40
+ num_patches,
41
+ config.num_pc_centers,
42
+ ) # token size for each modality
43
+ self.unmask_sz = config.unmask_sz # number of unmasked tokens
44
+
45
+ def get_input_embeddings(
46
+ self,
47
+ rgb: Optional[torch.Tensor],
48
+ depth: Optional[torch.Tensor],
49
+ pc: Optional[torch.Tensor],
50
+ add_mask: bool = True,
51
+ unmask_sz: Optional[int] = None,
52
+ forward_pc: bool = True,
53
+ shuffle_idx: Optional[torch.Tensor] = None,
54
+ ):
55
+ # provide at least one modality
56
+ assert any([rgb is not None, depth is not None, pc is not None])
57
+
58
+ # embeddings
59
+ rgb_emb = self.rgb_embeddings(rgb)
60
+ depth_emb = self.depth_embeddings(depth)
61
+ pc_emb, pc_centers, pc_knn = self.pc_embeddings(pc)
62
+ if not forward_pc:
63
+ pc = None
64
+ pc_emb = None
65
+
66
+ # concat embeddings
67
+ all_emb = concat_sequence_with_dummy([rgb_emb, depth_emb, pc_emb], self.embedding_sz)
68
+
69
+ # prepare shuffle indices
70
+ shuffle_idx, restore_idx, unmask_sz = prepare_shuffle_idx(
71
+ has_rgb=rgb is not None,
72
+ has_depth=depth is not None,
73
+ has_pc=pc is not None,
74
+ batch_size=all_emb.shape[0],
75
+ unmask_sz=self.unmask_sz if unmask_sz is None else unmask_sz,
76
+ dirichlet=self.dirichlet,
77
+ embedding_sz=self.embedding_sz,
78
+ add_mask=add_mask,
79
+ shuffle_idx=shuffle_idx,
80
+ device=all_emb.device,
81
+ )
82
+
83
+ # get unmasked embeddings
84
+ unmasked_emb = torch.gather(
85
+ all_emb, 1, shuffle_idx[:, :unmask_sz, None].repeat(1, 1, all_emb.shape[-1])
86
+ )
87
+
88
+ return EncoderModelOutput(
89
+ embedding=unmasked_emb,
90
+ pc_centers=pc_centers,
91
+ pc_knn=pc_knn,
92
+ shuffle_idx=shuffle_idx,
93
+ restore_idx=restore_idx,
94
+ add_mask=add_mask,
95
+ unmask_sz=unmask_sz,
96
+ )
97
+
98
+ def get_last_hidden_states(
99
+ self,
100
+ embedding_output: EncoderModelOutput,
101
+ output_attentions: bool = False,
102
+ output_hidden_states: bool = False,
103
+ ):
104
+ embedding = embedding_output.embedding
105
+
106
+ encoder_outputs = self.encoder(
107
+ embedding,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ )
111
+ sequence_output = encoder_outputs[0]
112
+ sequence_output = self.layernorm(sequence_output)
113
+
114
+ embedding_output.last_hidden_states = sequence_output
115
+ embedding_output.hidden_states = encoder_outputs.hidden_states
116
+ embedding_output.attentions = encoder_outputs.attentions
117
+
118
+ return embedding_output
119
+
120
+ def forward(
121
+ self,
122
+ rgb: Optional[torch.Tensor],
123
+ depth: Optional[torch.Tensor],
124
+ pc: Optional[torch.Tensor],
125
+ add_mask: bool = True,
126
+ unmask_sz: Optional[int] = None,
127
+ output_attentions: bool = False,
128
+ output_hidden_states: bool = False,
129
+ forward_pc: bool = True,
130
+ ):
131
+ embedding_output = self.get_input_embeddings(
132
+ rgb, depth, pc, add_mask, unmask_sz, forward_pc
133
+ )
134
+ return self.get_last_hidden_states(
135
+ embedding_output, output_attentions, output_hidden_states
136
+ )
137
+
138
+
139
+ class EmbodiedMAEForMaskedImageModeling(EmbodiedMAEModel):
140
+ def __init__(self, config: EmbodiedMAEConfig):
141
+ super().__init__(config)
142
+ self.decoder = EmbodiedMAEDecoder(config)
143
+
144
+ def forward(
145
+ self,
146
+ rgb: Optional[torch.Tensor],
147
+ depth: Optional[torch.Tensor],
148
+ pc: Optional[torch.Tensor],
149
+ add_mask: bool = True,
150
+ unmask_sz: Optional[int] = None,
151
+ output_attentions: bool = False,
152
+ output_hidden_states: bool = False,
153
+ forward_pc: bool = True,
154
+ ):
155
+ encoder_output = super().forward(
156
+ rgb, depth, pc, add_mask, unmask_sz, output_attentions, output_hidden_states, forward_pc
157
+ )
158
+ decoder_input = self.decoder.get_decoder_input(encoder_output)
159
+ return self.decoder(decoder_input)
160
+
161
+ @torch.no_grad()
162
+ def visualize(
163
+ self,
164
+ rgb: Optional[torch.Tensor],
165
+ depth: Optional[torch.Tensor],
166
+ pc: Optional[torch.Tensor],
167
+ mask_rgb: bool = False,
168
+ mask_depth: bool = False,
169
+ mask_pc: bool = False,
170
+ add_mask: bool = True,
171
+ unmask_sz: Optional[int] = None,
172
+ output_attentions: bool = False,
173
+ output_hidden_states: bool = False,
174
+ forward_pc: bool = True,
175
+ ):
176
+ _rgb = None if mask_rgb else rgb
177
+ _depth = None if mask_depth else depth
178
+ _pc = None if mask_pc else pc
179
+ encoder_output = super().forward(
180
+ _rgb,
181
+ _depth,
182
+ _pc,
183
+ add_mask,
184
+ unmask_sz,
185
+ output_attentions,
186
+ output_hidden_states,
187
+ forward_pc,
188
+ )
189
+ decoder_input = self.decoder.get_decoder_input(encoder_output)
190
+ return self.decoder.visualize(decoder_input, rgb, depth, pc)
191
+
192
+
193
+ __all__ = ["EmbodiedMAEModel", "EmbodiedMAEForMaskedImageModeling"]
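A hypothetical sketch of RGB-only feature extraction with the encoder above, using the local modules directly (they additionally require `einops`, `numba`, and `pytorch3d`). The model here is randomly initialized for illustration; in practice the pretrained weights would be loaded with `from_pretrained(..., trust_remote_code=True)`.

```python
import torch

from configuration_embodiedmae import EmbodiedMAEConfig
from modeling_embodiedmae import EmbodiedMAEModel

config = EmbodiedMAEConfig()          # defaults match the base-size config.json
model = EmbodiedMAEModel(config).eval()

rgb = torch.rand(2, 3, 224, 224) * 2 - 1  # RGB in [-1, 1]

with torch.no_grad():
    # Missing modalities are passed as None; with add_mask=False they are simply
    # treated as masked, so only the 196 RGB patch tokens are encoded.
    out = model(rgb, depth=None, pc=None, add_mask=False)

print(out.last_hidden_states.shape)  # torch.Size([2, 196, 768])
```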
modular_embodiedmae.py ADDED
@@ -0,0 +1,1195 @@
1
+ from copy import deepcopy
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple
4
+
5
+ import einops
6
+ import numba
7
+ import numpy as np
8
+ import pytorch3d.ops as torch3d_ops
9
+ import pytorch_lightning as L
10
+ import torch
11
+ import torch.nn as nn
12
+ from pytorch3d.loss import chamfer_distance
13
+ from transformers import (
14
+ AutoModelForMaskedImageModeling,
15
+ Dinov2Config,
16
+ Dinov2Model,
17
+ PretrainedConfig,
18
+ PreTrainedModel,
19
+ )
20
+ from transformers.models.dinov2.modeling_dinov2 import Dinov2Encoder, Dinov2Layer
21
+ from transformers.utils import ModelOutput
22
+
23
+ from .configuration_embodiedmae import EmbodiedMAEConfig
24
+
25
+
26
+ def concat_tensor(
27
+ tensors: List[torch.Tensor | None], dim: int = -1, **kwargs
28
+ ) -> Tuple[torch.Tensor, list]:
29
+ filtered_tensors = [t for t in tensors if t is not None]
30
+ mask = [(1.0 if t is not None else 0.0) for t in tensors]
31
+ return torch.cat(filtered_tensors, dim=dim, **kwargs), mask
32
+
33
+
34
+ def concat_sequence_with_dummy(
35
+ tensors: List[torch.Tensor | None], seq_lens: List[int]
36
+ ) -> torch.Tensor:
37
+ """Concatenate a sequence of tensors. If a tensor is `None`, it will be replaced by a dummy tensor of zeros.
38
+
39
+ Args:
40
+ tensors (List[torch.Tensor | None]):
41
+ Tensors to concatenate. If a tensor is `None`, it will be replaced by a dummy tensor of zeros.
42
+ seq_lens (List[int]):
43
+ Expected sequence length of each tensor.
44
+ """
45
+ assert len(tensors) == len(seq_lens)
46
+ for t in tensors:
47
+ if t is not None:
48
+ b, d = t.shape[0], t.shape[2]
49
+ device, dtype = t.device, t.dtype
50
+ x = []
51
+ for t, seq_len in zip(tensors, seq_lens):
52
+ if t is None:
53
+ x.append(torch.zeros((b, seq_len, d), dtype=dtype, device=device))
54
+ else:
55
+ x.append(t)
56
+ return torch.cat(x, dim=1)
57
+
58
+
59
+ def patchify(pixel_values, patch_size, num_channels, interpolate_pos_encoding: bool = False):
60
+ """
61
+ Args:
62
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ Pixel values.
64
+ interpolate_pos_encoding (`bool`, *optional*, default `False`):
65
+ interpolation flag passed during the forward pass.
66
+
67
+ Returns:
68
+ `torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
69
+ Patchified pixel values.
70
+ """
71
+ # sanity checks
72
+ if not interpolate_pos_encoding and (
73
+ pixel_values.shape[2] != pixel_values.shape[3] or pixel_values.shape[2] % patch_size != 0
74
+ ):
75
+ raise ValueError(
76
+ "Make sure the pixel values have a squared size that is divisible by the patch size"
77
+ )
78
+ if pixel_values.shape[1] != num_channels:
79
+ raise ValueError(
80
+ "Make sure the number of channels of the pixel values is equal to the one set in the configuration"
81
+ )
82
+
83
+ # patchify
84
+ batch_size = pixel_values.shape[0]
85
+ num_patches_h = pixel_values.shape[2] // patch_size
86
+ num_patches_w = pixel_values.shape[3] // patch_size
87
+ patchified_pixel_values = pixel_values.reshape(
88
+ batch_size, num_channels, num_patches_h, patch_size, num_patches_w, patch_size
89
+ )
90
+ patchified_pixel_values = torch.einsum("nchpwq->nhwpqc", patchified_pixel_values)
91
+ patchified_pixel_values = patchified_pixel_values.reshape(
92
+ batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels
93
+ )
94
+ return patchified_pixel_values
95
+
96
+
97
+ class CrossAttention(nn.Module):
98
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
99
+ super().__init__()
100
+ self.num_heads = num_heads
101
+
102
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
103
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
104
+
105
+ self.attn_drop = attn_drop
106
+ self.proj = nn.Linear(dim, dim)
107
+ self.proj_drop = nn.Dropout(proj_drop)
108
+
109
+ def forward(self, x, context):
110
+ q = self.q(x)
111
+ q = einops.rearrange(q, "b t (h d) -> b h t d", h=self.num_heads)
112
+ kv = self.kv(context)
113
+ kv = einops.rearrange(kv, "b t (h d) -> b h t d", h=self.num_heads)
114
+ k, v = torch.chunk(kv, 2, dim=-1)
115
+
116
+ attn_drop = self.attn_drop if self.training else 0.0
117
+ x = nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop)
118
+ x = einops.rearrange(x, "b h t d -> b t (h d)")
119
+ x = self.proj(x)
120
+ x = self.proj_drop(x)
121
+ return x
122
+
123
+
124
+ def unpatchify(patchified_pixel_values, patch_size, num_channels, original_image_size):
125
+ """
126
+ Args:
127
+ patchified_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
128
+ Patchified pixel values.
129
+ original_image_size (`Tuple[int, int]`, *optional*):
130
+ Original image size.
131
+
132
+ Returns:
133
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
134
+ Pixel values.
135
+ """
136
+ original_height, original_width = original_image_size
137
+ num_patches_h = original_height // patch_size
138
+ num_patches_w = original_width // patch_size
139
+ # sanity check
140
+ if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
141
+ raise ValueError(
142
+ f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
143
+ )
144
+
145
+ # unpatchify
146
+ batch_size = patchified_pixel_values.shape[0]
147
+ patchified_pixel_values = patchified_pixel_values.reshape(
148
+ batch_size,
149
+ num_patches_h,
150
+ num_patches_w,
151
+ patch_size,
152
+ patch_size,
153
+ num_channels,
154
+ )
155
+ patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
156
+ pixel_values = patchified_pixel_values.reshape(
157
+ batch_size,
158
+ num_channels,
159
+ num_patches_h * patch_size,
160
+ num_patches_w * patch_size,
161
+ )
162
+ return pixel_values
163
+
164
+
165
+ @numba.jit(nopython=True)
166
+ def get_mm_shuffle_indices(p, embedding_sz, unmask_sz=128):
167
+ b = p.shape[0]
168
+ n_modals = len(embedding_sz)
169
+ embedding_sz = np.array(embedding_sz)
170
+ indices = np.empty((b, embedding_sz.sum()), dtype=np.int64)
171
+
172
+ for i in numba.prange(b):
173
+ um_sz = np.round(p[i] * unmask_sz).astype(np.int64)
174
+ um_sz[-1] = unmask_sz - um_sz[:-1].sum()
175
+ m_sz = embedding_sz - um_sz
176
+ cm_um_sz = np.cumsum(um_sz)
177
+ cm_m_sz = np.cumsum(m_sz)
178
+
179
+ for j in range(n_modals):
180
+ shuffle_idx = np.argsort(np.random.random(embedding_sz[j])) + embedding_sz[:j].sum()
181
+ um = shuffle_idx[: um_sz[j]]
182
+ m = shuffle_idx[um_sz[j] :]
183
+
184
+ if j == 0:
185
+ indices[i, : cm_um_sz[j]] = um
186
+ indices[i, unmask_sz : cm_m_sz[j] + unmask_sz] = m
187
+ else:
188
+ indices[i, cm_um_sz[j - 1] : cm_um_sz[j]] = um
189
+ indices[i, cm_m_sz[j - 1] + unmask_sz : cm_m_sz[j] + unmask_sz] = m
190
+ return indices
191
+
192
+
193
+ def prepare_shuffle_idx(
194
+ has_rgb: bool,
195
+ has_depth: bool,
196
+ has_pc: bool,
197
+ batch_size: int,
198
+ unmask_sz: int,
199
+ dirichlet: torch.distributions.Dirichlet,
200
+ embedding_sz: Tuple[int, int, int],
201
+ # rgb: Optional[torch.Tensor],
202
+ # depth: Optional[torch.Tensor],
203
+ # pc: Optional[torch.Tensor],
204
+ add_mask: bool = True,
205
+ shuffle_idx: Optional[torch.Tensor] = None,
206
+ device: Optional[torch.device] = "cuda",
207
+ ):
208
+ """Prepare shuffle indices for the input embeddings.
209
+
210
+ Args:
211
+ rgb (Optional[torch.Tensor]):
212
+ RGB image from [-1, 1] range, shape (B, C, H, W).
213
+ depth (Optional[torch.Tensor]):
214
+ Depth map from [0, 2] range, shape (B, C, H, W).
215
+ pc (Optional[torch.Tensor]):
216
+ Point cloud data, shape (B, N, 3), where N is the number of points.
217
+ add_mask (bool, optional):
218
+ Whether to add a mask for masked autoencoding. Defaults to True.
219
+ unmask_sz (Optional[int], optional):
220
+ Size of the unmasked tokens. If None, it will be set to self.unmask_sz. Defaults to None.
221
+ shuffle_idx (Optional[torch.Tensor], optional):
222
+ Shuffle indices for the input embeddings. If provided, it will be used to restore the original order.
223
+
224
+ Returns:
225
+ _type_: _description_
226
+ """
227
+ # provide at least one modality
228
+ if not any([has_rgb, has_depth, has_pc]):
229
+ raise ValueError("provide at least one modality")
230
+
231
+ b = batch_size
232
+
233
+ if add_mask:
234
+ if shuffle_idx is not None:
235
+ restore_idx = torch.argsort(shuffle_idx, 1)
236
+ else:
237
+ mask = [float(each) for each in [has_rgb, has_depth, has_pc]]
238
+ # multi-modal shuffle
239
+ if sum(mask) > 1:
240
+ p = dirichlet.sample((b,)).numpy()
241
+ p = p * np.array(mask)[None]
242
+ p = p / p.sum(-1, keepdims=True)
243
+ shuffle_idx = get_mm_shuffle_indices(p, embedding_sz, unmask_sz)
244
+ # uni-modal shuffle
245
+ else:
246
+ # use one-hot modality weights so the indices keep shape (b, total_tokens)
+ p = np.eye(3)[mask.index(1.0)][None].repeat(b, 0)
+ shuffle_idx = get_mm_shuffle_indices(p, embedding_sz, unmask_sz)
247
+ restore_idx = np.argsort(shuffle_idx, 1)
248
+ shuffle_idx = torch.tensor(shuffle_idx, device=device)
249
+ restore_idx = torch.tensor(restore_idx, device=device)
250
+ else:
251
+ # the missing modality is regarded as masked
252
+ unmask_parts, mask_parts = [], []
253
+ cumsum_emb_sz = np.cumsum(embedding_sz)
254
+ for i, has_modal in enumerate([has_rgb, has_depth, has_pc]):
255
+ indices = torch.arange(
256
+ cumsum_emb_sz[i - 1] if i > 0 else 0,
257
+ cumsum_emb_sz[i],
258
+ device=device,
259
+ )
260
+ if has_modal:
261
+ unmask_parts.append(indices)
262
+ else:
263
+ mask_parts.append(indices)
264
+ shuffle_idx = torch.cat(unmask_parts + mask_parts, dim=0)[None].repeat(b, 1)
265
+ restore_idx = torch.argsort(shuffle_idx, 1)
266
+ unmask_sz = sum([len(part) for part in unmask_parts])
267
+
268
+ return shuffle_idx, restore_idx, unmask_sz
269
+
270
+
271
+ @numba.jit(nopython=True)
272
+ def get_shuffle_indices(embedding_sz):
273
+ shuffle_idx = np.argsort(np.random.random(embedding_sz))
274
+ return shuffle_idx
275
+
276
+
277
+ def torch_int(x):
278
+ import torch
279
+
280
+ return x.to(torch.int64) if torch.jit.is_tracing() and isinstance(x, torch.Tensor) else int(x)
281
+
282
+
283
+ def fps_and_knn(x: torch.Tensor, num_centers: int, num_knn: int):
284
+ dtype = x.dtype
285
+ x = x.to(torch.float32)
286
+ centers, _ = torch3d_ops.sample_farthest_points(x, K=num_centers) # (b, num_centers, 3)
287
+ knn_points = torch3d_ops.knn_points(
288
+ centers, x, K=num_knn, return_nn=True
289
+ ).knn # (b, num_centers, knn, 3)
290
+ return centers.to(dtype), knn_points.to(dtype)
291
+
292
+
293
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
294
+ """
295
+ Create 2D sin/cos positional embeddings.
296
+
297
+ Args:
298
+ embed_dim (`int`):
299
+ Embedding dimension.
300
+ grid_size (`int`):
301
+ The grid height and width.
302
+ add_cls_token (`bool`, *optional*, defaults to `False`):
303
+ Whether or not to add a classification (CLS) token.
304
+
305
+ Returns:
306
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
307
+ position embeddings (with or without classification token)
308
+ """
309
+ grid_h = np.arange(grid_size, dtype=np.float32)
310
+ grid_w = np.arange(grid_size, dtype=np.float32)
311
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
312
+ grid = np.stack(grid, axis=0)
313
+
314
+ grid = grid.reshape([2, 1, grid_size, grid_size])
315
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
316
+ if add_cls_token:
317
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
318
+ return pos_embed
319
+
320
+
321
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
322
+ if embed_dim % 2 != 0:
323
+ raise ValueError("embed_dim must be even")
324
+
325
+ # use half of dimensions to encode grid_h
326
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
327
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
328
+
329
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
330
+ return emb
331
+
332
+
333
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
334
+ """
335
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
336
+ """
337
+ if embed_dim % 2 != 0:
338
+ raise ValueError("embed_dim must be even")
339
+
340
+ omega = np.arange(embed_dim // 2, dtype=float)
341
+ omega /= embed_dim / 2.0
342
+ omega = 1.0 / 10000**omega # (D/2,)
343
+
344
+ pos = pos.reshape(-1) # (M,)
345
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
346
+
347
+ emb_sin = np.sin(out) # (M, D/2)
348
+ emb_cos = np.cos(out) # (M, D/2)
349
+
350
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
351
+ return emb
352
+
353
+
354
+ @dataclass
355
+ class EncoderModelOutput(ModelOutput):
356
+ embedding: torch.Tensor = None
357
+ pc_centers: torch.Tensor = None
358
+ pc_knn: torch.Tensor = None
359
+ shuffle_idx: torch.Tensor = None
360
+ restore_idx: torch.Tensor = None
361
+ last_hidden_states: Optional[torch.Tensor] = None
362
+ add_mask: bool = None
363
+ hidden_states: Optional[torch.Tensor] = None
364
+ attentions: Optional[Tuple[torch.Tensor]] = None
365
+ unmask_sz: int = None
366
+
367
+
368
+ @dataclass
369
+ class DecoderInput(ModelOutput):
370
+ rgb_embedding: torch.Tensor = None
371
+ depth_embedding: torch.Tensor = None
372
+ pc_embedding: torch.Tensor = None
373
+ unmasked_emb: torch.Tensor = None
374
+ shuffle_idx: torch.Tensor = None
375
+ pc_centers: torch.Tensor = None
376
+ pc_knn: torch.Tensor = None
377
+ add_mask: bool = None
378
+ unmask_sz: int = None
379
+
380
+
381
+ class SharedMlp(nn.Module):
382
+ def __init__(self, in_dim: int, out_dim: int):
383
+ super().__init__()
384
+ self.net = nn.Sequential(
385
+ nn.Linear(in_dim, out_dim),
386
+ nn.LayerNorm(out_dim),
387
+ nn.GELU(approximate="tanh"),
388
+ )
389
+
390
+ def forward(self, x: torch.Tensor):
391
+ return self.net(x)
392
+
393
+
394
+ class MaxPool(nn.Module):
395
+ def __init__(self, dim: int):
396
+ super().__init__()
397
+ self.dim = dim
398
+
399
+ def forward(self, x: torch.Tensor):
400
+ return x.max(self.dim)[0]
401
+
402
+
403
+ class PointGroupEmbedding(nn.Module):
404
+ def __init__(self, point_dim: int, d_model: int):
405
+ super().__init__()
406
+ self.net = nn.Sequential(
407
+ SharedMlp(point_dim, 64),
408
+ SharedMlp(64, 128),
409
+ SharedMlp(128, 256),
410
+ MaxPool(-2),
411
+ nn.Linear(256, d_model),
412
+ )
413
+
414
+ def forward(self, x: torch.Tensor):
415
+ return self.net(x)
416
+
417
+
418
+ class Conv2dPatchify(nn.Module):
419
+ def __init__(
420
+ self,
421
+ patch_size: int = 14,
422
+ hidden_size: int = 768,
423
+ num_channels: int = 3,
424
+ ):
425
+ super().__init__()
426
+ self.num_channels = num_channels
427
+ self.patchify = nn.Conv2d(
428
+ num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
429
+ )
430
+
431
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
432
+ num_channels = pixel_values.shape[-3]
433
+ if num_channels != self.num_channels:
434
+ raise ValueError(
435
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
436
+ f" Expected {self.num_channels} but got {num_channels}."
437
+ )
438
+ embeddings = self.patchify(pixel_values).flatten(2).transpose(1, 2)
439
+ return embeddings
440
+
441
+
442
+ class PatchEmbeddings(nn.Module):
443
+ def __init__(
444
+ self,
445
+ image_size: int = 224,
446
+ patch_size: int = 14,
447
+ hidden_size: int = 768,
448
+ num_channels: int = 3,
449
+ dropout: float = 0.0,
450
+ ):
451
+ super().__init__()
452
+ self.num_channels = num_channels
+ self.patch_size = patch_size  # required by interpolate_pos_encoding
453
+ self.embeddings = Conv2dPatchify(patch_size, hidden_size, num_channels)
454
+ # Use learnable positional embeddings initialized at sin-cos
455
+ pos_emb = get_2d_sincos_pos_embed(hidden_size, image_size // patch_size)
456
+ pos_emb = torch.tensor(pos_emb, dtype=torch.float32)[None]
457
+ self.position_embeddings = nn.Parameter(pos_emb)
458
+ self.dropout = nn.Dropout(dropout)
459
+
460
+ def interpolate_pos_encoding(
461
+ self, embeddings: torch.Tensor, height: int, width: int
462
+ ) -> torch.Tensor:
463
+ """
464
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
465
+ images. This method is also adapted to support torch.jit tracing and interpolation at torch.float32 precision.
466
+
467
+ Adapted from:
468
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
469
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
470
+ """
471
+
472
+ num_patches = embeddings.shape[1]
473
+ num_positions = self.position_embeddings.shape[1]
474
+
475
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
476
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
477
+ return self.position_embeddings
478
+
479
+ patch_pos_embed = self.position_embeddings  # no CLS token in these embeddings, keep all positions
480
+
481
+ dim = embeddings.shape[-1]
482
+
483
+ new_height = height // self.patch_size
484
+ new_width = width // self.patch_size
485
+
486
+ sqrt_num_positions = torch_int(num_positions**0.5)
487
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
488
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
489
+ target_dtype = patch_pos_embed.dtype
490
+ patch_pos_embed = nn.functional.interpolate(
491
+ patch_pos_embed.to(torch.float32),
492
+ size=(new_height, new_width),
493
+ mode="bicubic",
494
+ align_corners=False,
495
+ ).to(dtype=target_dtype)
496
+
497
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
498
+
499
+ return patch_pos_embed
500
+
501
+ def forward(self, pixel_values: Optional[torch.Tensor]) -> torch.Tensor:
502
+ if pixel_values is None:
503
+ return None
504
+ batch_size, _, height, width = pixel_values.shape
505
+ target_dtype = self.embeddings.patchify.weight.dtype
506
+ embeddings = self.embeddings(pixel_values.to(dtype=target_dtype))
507
+ # add positional encoding to each token
508
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
509
+ embeddings = self.dropout(embeddings)
510
+ return embeddings
511
+
512
+
513
+ class EmbodiedMAERGBEmbeddings(PatchEmbeddings):
514
+ def __init__(self, config: EmbodiedMAEConfig):
515
+ super().__init__(
516
+ image_size=config.image_size,
517
+ patch_size=config.patch_size,
518
+ hidden_size=config.hidden_size,
519
+ num_channels=3,
520
+ dropout=0.0,
521
+ )
522
+
523
+
524
+ class EmbodiedMAEDepthEmbeddings(PatchEmbeddings):
525
+ def __init__(self, config: EmbodiedMAEConfig):
526
+ super().__init__(
527
+ image_size=config.image_size,
528
+ patch_size=config.patch_size,
529
+ hidden_size=config.hidden_size,
530
+ num_channels=1,
531
+ dropout=0.0,
532
+ )
533
+
534
+
535
+ class EmbodiedMAEPointCloudEmbeddings(nn.Module):
536
+ def __init__(self, config: EmbodiedMAEConfig):
537
+ super().__init__()
538
+ self.num_centers, self.num_knn = config.num_pc_centers, config.num_pc_knn
539
+ self.knn_embeddings = PointGroupEmbedding(3, config.hidden_size)
540
+ self.center_embeddings = nn.Sequential(
541
+ nn.Linear(3, config.hidden_size),
542
+ nn.GELU(approximate="tanh"),
543
+ nn.Linear(config.hidden_size, config.hidden_size),
544
+ )
545
+
546
+ def forward(self, point_cloud: Optional[torch.Tensor]) -> torch.Tensor:
547
+ if point_cloud is None:
548
+ return None, None, None
549
+ centers, knn_points = fps_and_knn(
550
+ point_cloud, num_centers=self.num_centers, num_knn=self.num_knn
551
+ )
552
+ normed_knn_points = knn_points - centers.unsqueeze(-2)
553
+ center_emb = self.center_embeddings(centers)
554
+ knn_emb = self.knn_embeddings(normed_knn_points)
555
+ return center_emb + knn_emb, centers, normed_knn_points
556
+
+
833
+ class EmbodiedMAEDecoder(nn.Module):
834
+ def __init__(self, config: EmbodiedMAEConfig):
835
+ super().__init__()
836
+ image_size = config.image_size
837
+ patch_size = config.patch_size
838
+ self.config = config
839
+
840
+ pos_emb = get_2d_sincos_pos_embed(config.decoder_hidden_size, image_size // patch_size)
841
+ self.rgb_pos_embed = nn.Parameter(torch.tensor(pos_emb, dtype=torch.float32)[None])
842
+ self.depth_pos_embed = nn.Parameter(torch.tensor(pos_emb, dtype=torch.float32)[None])
843
+ self.pc_pos_embed = nn.Sequential(
844
+ nn.Linear(3, config.decoder_hidden_size),
845
+ nn.GELU(approximate="tanh"),
846
+ nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size),
847
+ )
848
+
849
+ num_patches = (config.image_size // config.patch_size) ** 2
850
+ self.embedding_sz = (num_patches, num_patches, config.num_pc_centers)
851
+ self.unmask_sz = config.unmask_sz
852
+ self.context_pos_emb = nn.Parameter(
853
+ torch.randn(sum(self.embedding_sz), config.decoder_hidden_size)
854
+ )
855
+ nn.init.trunc_normal_(self.context_pos_emb, std=config.initializer_range)
856
+
857
+ self.rgb_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
858
+ self.depth_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
859
+ self.pc_query_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
860
+ self.rgb_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
861
+ self.depth_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
862
+ self.pc_query_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
863
+
864
+ self.context_proj = nn.Linear(config.hidden_size, config.decoder_hidden_size)
865
+ self.context_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
866
+
867
+ self.rgb_cross_attn = CrossAttention(config.decoder_hidden_size)
868
+ self.depth_cross_attn = CrossAttention(config.decoder_hidden_size)
869
+ self.pc_cross_attn = CrossAttention(config.decoder_hidden_size)
870
+
871
+ dec_config = deepcopy(config)
872
+ dec_config.hidden_size = config.decoder_hidden_size
873
+ dec_config.num_hidden_layers = config.decoder_num_hidden_layers
874
+ dec_config.num_attention_heads = config.decoder_num_attention_heads
875
+
876
+ self.rgb_layer = nn.ModuleList(
877
+ [Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
878
+ )
879
+ self.depth_layer = nn.ModuleList(
880
+ [Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
881
+ )
882
+ self.pc_layer = nn.ModuleList(
883
+ [Dinov2Layer(dec_config) for _ in range(dec_config.num_hidden_layers)]
884
+ )
885
+
886
+ self.rgb_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
887
+ self.depth_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
888
+ self.pc_out_norm = nn.LayerNorm(config.decoder_hidden_size, eps=config.layer_norm_eps)
889
+
890
+ self.rgb_out_proj = nn.Linear(config.decoder_hidden_size, config.patch_size**2 * 3)
891
+ self.depth_out_proj = nn.Linear(config.decoder_hidden_size, config.patch_size**2)
892
+ self.pc_out_proj = nn.Linear(config.decoder_hidden_size, config.num_pc_knn * 3)
893
+
894
+ self.norm_pix_loss = config.norm_pix_loss
895
+
896
+ def get_decoder_input(self, encoder_output: EncoderModelOutput):
897
+ """Convert the encoder output to decoder input."""
898
+ unmasked_emb = encoder_output.last_hidden_states
899
+ unmask_sz = encoder_output.unmask_sz
900
+
901
+ masked_emb = torch.zeros(
902
+ (
903
+ unmasked_emb.shape[0],
904
+ sum(self.embedding_sz) - unmask_sz,
905
+ unmasked_emb.shape[-1],
906
+ ),
907
+ device=unmasked_emb.device,
908
+ dtype=unmasked_emb.dtype,
909
+ )
910
+ all_emb = torch.cat([unmasked_emb, masked_emb], dim=1)
911
+ all_emb = torch.gather(
912
+ all_emb,
913
+ 1,
914
+ encoder_output.restore_idx.unsqueeze(-1).repeat(1, 1, all_emb.shape[-1]),
915
+ )
916
+ rgb_emb, depth_emb, pc_emb = torch.split(all_emb, self.embedding_sz, dim=1)
917
+
918
+ return DecoderInput(
919
+ rgb_embedding=rgb_emb,
920
+ depth_embedding=depth_emb,
921
+ pc_embedding=pc_emb,
922
+ unmasked_emb=unmasked_emb,
923
+ shuffle_idx=encoder_output.shuffle_idx,
924
+ pc_centers=encoder_output.pc_centers,
925
+ pc_knn=encoder_output.pc_knn,
926
+ add_mask=encoder_output.add_mask,
927
+ unmask_sz=unmask_sz,
928
+ )
929
+
930
+ def forward(self, decoder_input: DecoderInput):
931
+ unmask_sz = decoder_input.unmask_sz if decoder_input.unmask_sz else self.unmask_sz
932
+ rgb_query = self.rgb_query_proj(decoder_input.rgb_embedding)
933
+ depth_query = self.depth_query_proj(decoder_input.depth_embedding)
934
+ pc_query = self.pc_query_proj(decoder_input.pc_embedding)
935
+ rgb_query = self.rgb_query_norm(rgb_query + self.rgb_pos_embed)
936
+ depth_query = self.depth_query_norm(depth_query + self.depth_pos_embed)
937
+ if decoder_input.pc_centers is not None:
938
+ pc_pos_embed = self.pc_pos_embed(decoder_input.pc_centers)
939
+ else:
940
+ pc_pos_embed = 0
941
+ pc_query = self.pc_query_norm(pc_query + pc_pos_embed)
942
+
943
+ context = self.context_proj(decoder_input.unmasked_emb)
944
+ shuffle_idx = decoder_input.shuffle_idx[:, :unmask_sz]
945
+ context_pos_emb = self.context_pos_emb[shuffle_idx]
946
+ context = self.context_norm(context + context_pos_emb)
947
+
948
+ rgb_emb = self.rgb_cross_attn(rgb_query, context)
949
+ depth_emb = self.depth_cross_attn(depth_query, context)
950
+ pc_emb = self.pc_cross_attn(pc_query, context)
951
+
952
+ for layers in self.rgb_layer:
953
+ rgb_emb = layers(rgb_emb)[0]
954
+ for layers in self.depth_layer:
955
+ depth_emb = layers(depth_emb)[0]
956
+ for layers in self.pc_layer:
957
+ pc_emb = layers(pc_emb)[0]
958
+
959
+ rgb_emb = self.rgb_out_norm(rgb_emb)
960
+ depth_emb = self.depth_out_norm(depth_emb)
961
+ pc_emb = self.pc_out_norm(pc_emb)
962
+
963
+ rgb_out = self.rgb_out_proj(rgb_emb)
964
+ depth_out = self.depth_out_proj(depth_emb)
965
+ pc_out = self.pc_out_proj(pc_emb)
966
+
967
+ return rgb_out, depth_out, pc_out
968
+
969
+ def get_loss(self, decoder_input: DecoderInput, rgb, depth, pc):
970
+ unmask_sz = decoder_input.unmask_sz
971
+ b = rgb.shape[0]
972
+ rgb_out, depth_out, pc_out = self(decoder_input)
973
+
974
+ target_rgb, target_depth = (
975
+ patchify(rgb, self.config.patch_size, 3),
976
+ patchify(depth, self.config.patch_size, 1),
977
+ )
978
+ target_pc = decoder_input.pc_knn * 10.0 # scale point offsets by 10 (meters -> decimeters)
979
+
980
+ if self.norm_pix_loss:
981
+ rgb_mean, rgb_std = (
982
+ target_rgb.mean(-1, keepdim=True),
983
+ target_rgb.std(-1, keepdim=True),
984
+ )
985
+ depth_mean, depth_std = (
986
+ target_depth.mean(-1, keepdim=True),
987
+ target_depth.std(-1, keepdim=True),
988
+ )
989
+ else:
990
+ rgb_mean, rgb_std = 0.0, 1.0
991
+ depth_mean, depth_std = 0.0, 1.0
992
+
993
+ target_rgb = (target_rgb - rgb_mean) / (rgb_std + 1e-8)
994
+ target_depth = (target_depth - depth_mean) / (depth_std + 1e-8)
995
+
996
+ mask = torch.ones((b, sum(self.embedding_sz)), device=rgb.device)
997
+ mask[
998
+ torch.arange(b, device=rgb.device)[:, None],
999
+ decoder_input.shuffle_idx[:, :unmask_sz],
1000
+ ] = 0
1001
+ rgb_mask, depth_mask, pc_mask = torch.split(mask, self.embedding_sz, dim=1)
1002
+
1003
+ rgb_loss = ((rgb_out - target_rgb).pow(2).mean(-1) * rgb_mask).sum() / rgb_mask.sum()
1004
+ depth_loss = (
1005
+ (depth_out - target_depth).abs().mean(-1) * depth_mask
1006
+ ).sum() / depth_mask.sum()
1007
+
1008
+ pred_pc = einops.rearrange(pc_out[pc_mask.bool()], "b (k n) -> b k n", n=3)
1009
+ target_pc = target_pc[pc_mask.bool()]
1010
+ pc_loss = chamfer_distance(pred_pc.float(), target_pc.float(), norm=1)[0]
1011
+
1012
+ return rgb_loss, depth_loss, pc_loss
1013
+
1014
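+     # A minimal training-step sketch (illustrative only; `encoder`, `decoder`,
+     # and the batch tensors below are placeholders, not names defined in this
+     # module):
+     #
+     #     encoder_output = encoder(rgb, depth, pc)  # EncoderModelOutput
+     #     decoder_input = decoder.get_decoder_input(encoder_output)
+     #     rgb_loss, depth_loss, pc_loss = decoder.get_loss(decoder_input, rgb, depth, pc)
+     #     loss = rgb_loss + depth_loss + pc_loss
+     #     loss.backward()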
+     @torch.no_grad()
+     def visualize(
+         self, decoder_input: DecoderInput, rgb: torch.Tensor, depth: torch.Tensor, pc: torch.Tensor
+     ):
+         """Visualize the predictions of the decoder.
+
+         Args:
+             decoder_input (DecoderInput):
+                 `decoder_input` from `get_decoder_input`.
+             rgb (torch.Tensor):
+                 RGB image with shape (B, 3, H, W) in [-1, 1] range.
+             depth (torch.Tensor):
+                 Depth map with shape (B, 1, H, W) in [0, inf] range. Unit is meters.
+             pc (torch.Tensor):
+                 Point cloud with shape (B, N, 3), where N=8192 is the number of points. Unit is meters.
+
+         Returns:
+             tuple: `(plt_rgb, plt_depth, plt_pc)` on CPU. `plt_rgb` has shape
+                 (B, 3*H, W, 3) and `plt_depth` has shape (B, 3*H, W, 1), both in
+                 [0, 1], stacking the input, the masked input, and the reconstruction
+                 along the height axis. `plt_pc` contains the reconstructed
+                 point-cloud groups in meters, with shape (B, G, num_pc_knn, 3),
+                 where G is the number of point-cloud groups.
+         """
+         rgb_out, depth_out, pc_out = self(decoder_input)
+         pc_centers = decoder_input.pc_centers
+         pc_out = einops.rearrange(pc_out, "... (k n) -> ... k n", n=3)
+         plt_pc = pc_out / 10.0 + pc_centers.unsqueeze(-2)
+         b = rgb_out.shape[0]
+         unmask_sz = decoder_input.unmask_sz
+
+         target_rgb, target_depth = (
+             patchify(rgb, self.config.patch_size, 3),
+             patchify(depth, self.config.patch_size, 1),
+         )
+
+         if self.norm_pix_loss:
+             rgb_mean, rgb_std = (
+                 target_rgb.mean(-1, keepdim=True),
+                 target_rgb.std(-1, keepdim=True),
+             )
+             depth_mean, depth_std = (
+                 target_depth.mean(-1, keepdim=True),
+                 target_depth.std(-1, keepdim=True),
+             )
+         else:
+             rgb_mean, rgb_std = 0.0, 1.0
+             depth_mean, depth_std = 0.0, 1.0
+
+         pred_rgb = rgb_out * (rgb_std + 1e-8) + rgb_mean
+         pred_depth = depth_out * (depth_std + 1e-8) + depth_mean
+
+         mask = torch.ones((b, sum(self.embedding_sz)), device=rgb.device)
+         if decoder_input.add_mask:
+             mask[
+                 torch.arange(b, device=rgb.device)[:, None],
+                 decoder_input.shuffle_idx[:, :unmask_sz],
+             ] = 0
+         rgb_mask, depth_mask, _ = torch.split(mask, self.embedding_sz, dim=1)
+
+         masked_rgb = torch.ones_like(target_rgb) - 2.0
+         masked_rgb[~rgb_mask.bool()] = target_rgb[~rgb_mask.bool()].to(masked_rgb.dtype)
+         masked_rgb = unpatchify(
+             masked_rgb,
+             self.config.patch_size,
+             3,
+             (self.config.image_size, self.config.image_size),
+         )
+         pred_rgb[~rgb_mask.bool()] = target_rgb[~rgb_mask.bool()].to(pred_rgb.dtype)
+         pred_rgb = unpatchify(
+             pred_rgb,
+             self.config.patch_size,
+             3,
+             (self.config.image_size, self.config.image_size),
+         )
+
+         masked_depth = torch.zeros_like(pred_depth)
+         masked_depth[~depth_mask.bool()] = target_depth[~depth_mask.bool()].to(masked_depth.dtype)
+         masked_depth = unpatchify(
+             masked_depth,
+             self.config.patch_size,
+             1,
+             (self.config.image_size, self.config.image_size),
+         )
+         pred_depth[~depth_mask.bool()] = target_depth[~depth_mask.bool()].to(pred_depth.dtype)
+         pred_depth = unpatchify(
+             pred_depth,
+             self.config.patch_size,
+             1,
+             (self.config.image_size, self.config.image_size),
+         )
+
+         plt_rgb = (
+             torch.cat([rgb.float(), masked_rgb.float(), pred_rgb.float()], 2) * 0.5 + 0.5
+         ).clip(0, 1)
+         plt_depth = (
+             torch.cat([depth.float(), masked_depth.float(), pred_depth.float()], 2) / 2.0
+         ).clip(0, 1)
+
+         return (
+             plt_rgb.permute(0, 2, 3, 1).cpu(),
+             plt_depth.permute(0, 2, 3, 1).cpu(),
+             plt_pc.cpu(),
+         )
+
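+     # A minimal plotting sketch for the tensors returned by `visualize`
+     # (illustrative only; `decoder` and the matplotlib usage are assumptions,
+     # not part of this module):
+     #
+     #     import matplotlib.pyplot as plt
+     #     plt_rgb, plt_depth, plt_pc = decoder.visualize(decoder_input, rgb, depth, pc)
+     #     plt.imshow(plt_rgb[0].numpy())  # input / masked / reconstruction stacked vertically
+     #     plt.imshow(plt_depth[0, ..., 0].numpy(), cmap="gray")
+     #     plt.show()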