from collections import OrderedDict

import torch
import torch.nn as nn

from models.rdt.blocks import (FinalLayer, RDTBlock, TimestepEmbedder,
                               get_1d_sincos_pos_embed_from_grid,
                               get_multimodal_cond_pos_embed)


class RDT(nn.Module):
    """
    Class for Robotics Diffusion Transformers.
    """
|
    def __init__(
        self,
        output_dim=128,
        horizon=32,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        max_lang_cond_len=1024,
        img_cond_len=4096,
        lang_pos_embed_config=None,
        img_pos_embed_config=None,
        dtype=torch.bfloat16
    ):
        """
        output_dim: dimension of each predicted action token.
        horizon: number of action tokens to predict.
        hidden_size: Transformer hidden dimension.
        depth: number of Transformer blocks.
        num_heads: number of attention heads.
        max_lang_cond_len: maximum number of language condition tokens.
        img_cond_len: fixed number of image condition tokens.
        lang_pos_embed_config, img_pos_embed_config: optional multimodal
            layouts for the condition positional embeddings; a plain 1D
            sin-cos embedding is used when None.
        dtype: data type of the model parameters.
        """
        super().__init__()
        self.horizon = horizon
        self.hidden_size = hidden_size
        self.max_lang_cond_len = max_lang_cond_len
        self.img_cond_len = img_cond_len
        self.dtype = dtype
        self.lang_pos_embed_config = lang_pos_embed_config
        self.img_pos_embed_config = img_pos_embed_config

        # Scalar embedders for the diffusion timestep and the control frequency
        self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
        self.freq_embedder = TimestepEmbedder(hidden_size, dtype=dtype)
|
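        # Learnable positional embeddings. The main token sequence is laid out
        # as [timestep, ctrl_freq, state, action_1, ..., action_horizon],
        # hence a length of horizon + 3.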
        self.x_pos_embed = nn.Parameter(
            torch.zeros(1, horizon+3, hidden_size))
        # Positional embedding for the language condition tokens
        self.lang_cond_pos_embed = nn.Parameter(
            torch.zeros(1, max_lang_cond_len, hidden_size))
        # Positional embedding for the image condition tokens
        self.img_cond_pos_embed = nn.Parameter(
            torch.zeros(1, img_cond_len, hidden_size))

        # A stack of Transformer blocks that cross-attend to the conditions
        self.blocks = nn.ModuleList([
            RDTBlock(hidden_size, num_heads) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_dim)
        self.initialize_weights()
|
    def initialize_weights(self):
        # Xavier-uniform weights and zero biases for all linear layers
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
|
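        # Initialize the positional embedding tables with fixed sin-cos values.
        # The main sequence uses a multimodal table that matches the layout
        # [timestep, ctrl_freq, state, action x horizon].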
        x_pos_embed = get_multimodal_cond_pos_embed(
            embed_dim=self.hidden_size,
            mm_cond_lens=OrderedDict([
                ('timestep', 1),
                ('ctrl_freq', 1),
                ('state', 1),
                ('action', self.horizon),
            ])
        )
        self.x_pos_embed.data.copy_(
            torch.from_numpy(x_pos_embed).float().unsqueeze(0))
|
        # Language condition: a plain 1D sin-cos table unless a custom
        # multimodal layout is configured
        if self.lang_pos_embed_config is None:
            lang_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(
                self.hidden_size, torch.arange(self.max_lang_cond_len))
        else:
            lang_cond_pos_embed = get_multimodal_cond_pos_embed(
                embed_dim=self.hidden_size,
                mm_cond_lens=OrderedDict(self.lang_pos_embed_config),
                embed_modality=False
            )
        self.lang_cond_pos_embed.data.copy_(
            torch.from_numpy(lang_cond_pos_embed).float().unsqueeze(0))

        # Image condition: same scheme as the language condition
        if self.img_pos_embed_config is None:
            img_cond_pos_embed = get_1d_sincos_pos_embed_from_grid(
                self.hidden_size, torch.arange(self.img_cond_len))
        else:
            img_cond_pos_embed = get_multimodal_cond_pos_embed(
                embed_dim=self.hidden_size,
                mm_cond_lens=OrderedDict(self.img_pos_embed_config),
                embed_modality=False
            )
        self.img_cond_pos_embed.data.copy_(
            torch.from_numpy(img_cond_pos_embed).float().unsqueeze(0))
|
        # Initialize the timestep and frequency embedding MLPs
        # with normally distributed weights
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.freq_embedder.mlp[2].weight, std=0.02)

        # Zero-initialize the last projection of the final layer,
        # so the model starts from a neutral output
        nn.init.constant_(self.final_layer.ffn_final.fc2.weight, 0)
        nn.init.constant_(self.final_layer.ffn_final.fc2.bias, 0)

        # Cast all parameters and buffers to the working dtype
        self.to(self.dtype)
|
    def forward(self, x, freq, t, lang_c, img_c, lang_mask=None, img_mask=None):
        """
        Forward pass of RDT.

        x: (B, T, D), state + action token sequence, T = horizon + 1,
            dimension D is assumed to be the same as the hidden size.
        freq: (B,), control frequency of each sample.
        t: (B,) or (1,), diffusion timesteps.
        lang_c: (B, L_lang, D) or None, language condition tokens (variable length),
            dimension D is assumed to be the same as the hidden size.
        img_c: (B, L_img, D) or None, image condition tokens (fixed length),
            dimension D is assumed to be the same as the hidden size.
        lang_mask: (B, L_lang) or None, language condition mask (True for valid).
        img_mask: (B, L_img) or None, image condition mask (True for valid).

        Returns: (B, horizon, output_dim), the denoised action tokens.
        See the `__main__` block at the bottom of this file for a usage sketch.
        """
|
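        # Embed the two scalars and prepend them to the token sequence, giving
        # the full input [timestep, ctrl_freq, state, action_1, ..., action_horizon]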
        t = self.t_embedder(t).unsqueeze(1)           # (B, 1, D) or (1, 1, D)
        freq = self.freq_embedder(freq).unsqueeze(1)  # (B, 1, D)
        # Broadcast a shared timestep embedding across the batch
        if t.shape[0] == 1:
            t = t.expand(x.shape[0], -1, -1)
        x = torch.cat([t, freq, x], dim=1)            # (B, T + 2, D)

        # Add the multimodal positional embedding to the main sequence
        x = x + self.x_pos_embed
|
        # The language table is truncated to the actual number of tokens;
        # image tokens always have the fixed length img_cond_len
        lang_c = lang_c + self.lang_cond_pos_embed[:, :lang_c.shape[1]]
        img_c = img_c + self.img_cond_pos_embed
|
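        # Inject the two conditions via alternating cross-attention:
        # even-indexed blocks attend to the language tokens,
        # odd-indexed blocks to the image tokens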
        conds = [lang_c, img_c]
        masks = [lang_mask, img_mask]
        for i, block in enumerate(self.blocks):
            c, mask = conds[i % 2], masks[i % 2]
            x = block(x, c, mask)
        x = self.final_layer(x)  # (B, T + 2, output_dim)

        # Only the action tokens are returned
        x = x[:, -self.horizon:]
        return x
|
|
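# A minimal smoke test sketching how RDT is called. The tiny configuration and
# the tensor contents below are illustrative assumptions for a quick shape
# check, not the settings used to train the released models.
if __name__ == "__main__":
    model = RDT(
        output_dim=128,
        horizon=8,
        hidden_size=64,
        depth=2,
        num_heads=4,
        max_lang_cond_len=16,
        img_cond_len=32,
        dtype=torch.float32,
    )
    B, D = 2, 64
    x = torch.randn(B, model.horizon + 1, D)         # state token + noisy action tokens
    freq = torch.tensor([25.0, 30.0])                # control frequencies in Hz (assumed)
    t = torch.randint(0, 1000, (B,))                 # diffusion timesteps
    lang_c = torch.randn(B, 10, D)                   # 10 language condition tokens
    img_c = torch.randn(B, 32, D)                    # img_cond_len image condition tokens
    lang_mask = torch.ones(B, 10, dtype=torch.bool)
    img_mask = torch.ones(B, 32, dtype=torch.bool)
    out = model(x, freq, t, lang_c, img_c, lang_mask=lang_mask, img_mask=img_mask)
    print(out.shape)  # expected: (2, 8, 128) == (B, horizon, output_dim)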