|
import re |
|
from pathlib import Path |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler |
|
from diffusers.schedulers.scheduling_dpmsolver_multistep import \ |
|
DPMSolverMultistepScheduler |
|
|
|
from models.hub_mixin import CompatiblePyTorchModelHubMixin |
|
from models.rdt.model import RDT |
|
|
|
|
|
class RDTRunner(
        nn.Module,
        CompatiblePyTorchModelHubMixin,
        repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"
    ):
    '''
    Diffusion runner around an `RDT` backbone.

    Owns the backbone, three input adaptors (language, image, state/action)
    that project raw tokens to the backbone's hidden size, a DDPM scheduler
    used to add noise during training, and a DPM-Solver multistep scheduler
    used for sampling. `forward` returns the training loss; `predict_action`
    runs the sampling loop.
    '''

    def __init__(self, *, action_dim, pred_horizon, config,
                 lang_token_dim, img_token_dim, state_token_dim,
                 max_lang_cond_len, img_cond_len, lang_pos_embed_config=None,
                 img_pos_embed_config=None, dtype=torch.bfloat16):
        '''
        action_dim: size of the last dimension of the predicted actions.
        pred_horizon: number of action steps predicted per sample.
        config: nested mapping read with the keys
            'rdt' -> 'hidden_size', 'depth', 'num_heads';
            'lang_adaptor', 'img_adaptor', 'state_adaptor' -> projector
                type strings (see `build_condition_adapter`);
            'noise_scheduler' -> 'num_train_timesteps', 'beta_schedule',
                'prediction_type', 'clip_sample', 'num_inference_timesteps'.
        lang_token_dim / img_token_dim / state_token_dim: input feature
            sizes of the respective adaptors.
        max_lang_cond_len, img_cond_len, lang_pos_embed_config,
        img_pos_embed_config, dtype: forwarded unchanged to `RDT`.
        '''
        super().__init__()

        # Diffusion backbone.
        rdt_cfg = config['rdt']
        hidden_size = rdt_cfg['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=rdt_cfg['depth'],
            num_heads=rdt_cfg['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Adaptors projecting each modality to the backbone's hidden size.
        self.lang_adaptor = self.build_condition_adapter(
            config['lang_adaptor'],
            in_features=lang_token_dim,
            out_features=hidden_size,
        )
        self.img_adaptor = self.build_condition_adapter(
            config['img_adaptor'],
            in_features=img_token_dim,
            out_features=hidden_size,
        )
        # Input width is doubled because every state/action token is
        # concatenated with its mask along the last dimension before it is
        # fed to this adaptor.
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,
            out_features=hidden_size,
        )

        # Schedulers: DDPM adds noise for training, DPM-Solver samples.
        sched_cfg = config['noise_scheduler']
        shared_sched_kwargs = dict(
            num_train_timesteps=sched_cfg['num_train_timesteps'],
            beta_schedule=sched_cfg['beta_schedule'],
            prediction_type=sched_cfg['prediction_type'],
        )
        self.noise_scheduler = DDPMScheduler(
            **shared_sched_kwargs,
            clip_sample=sched_cfg['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            **shared_sched_kwargs,
        )

        self.num_train_timesteps = sched_cfg['num_train_timesteps']
        self.num_inference_timesteps = sched_cfg['num_inference_timesteps']
        self.prediction_type = sched_cfg['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report the parameter count of the backbone plus all adaptors.
        counted_modules = (
            self.model,
            self.lang_adaptor,
            self.img_adaptor,
            self.state_adaptor,
        )
        total_params = sum(
            p.numel()
            for module in counted_modules
            for p in module.parameters()
        )
        print("Diffusion params: %e" % total_params)

    def build_condition_adapter(
            self, projector_type, in_features, out_features):
        '''
        Build the projection module described by `projector_type`.

        projector_type: 'linear' for a single `nn.Linear`, or
            'mlp{N}x_gelu' for N linear layers with tanh-approximated GELU
            between consecutive layers.
        in_features: input width of the first linear layer.
        out_features: output width of every linear layer.

        return: the projection module
        raises: ValueError if `projector_type` matches neither form
        '''
        if projector_type == 'linear':
            return nn.Linear(in_features, out_features)

        mlp_spec = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if not mlp_spec:
            raise ValueError(f'Unknown projector type: {projector_type}')

        num_linear = int(mlp_spec.group(1))
        layers = [nn.Linear(in_features, out_features)]
        # The first layer is already in place; each further layer is
        # preceded by a GELU activation.
        for _ in range(num_linear - 1):
            layers += [
                nn.GELU(approximate="tanh"),
                nn.Linear(out_features, out_features),
            ]
        return nn.Sequential(*layers)

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)

        return: adapted (..., hidden_size) for all input tokens, in the
            order (language, image, state)
        '''
        return (
            self.lang_adaptor(lang_tokens),
            self.img_adaptor(img_tokens),
            self.state_adaptor(state_tokens),
        )

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond,
                           state_traj, action_mask, ctrl_freqs):
        '''
        Denoise a random action sequence conditioned on the given inputs.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device, dtype = state_traj.device, state_traj.dtype
        batch_size = state_traj.shape[0]

        # Start from Gaussian noise matching the state's dtype and device.
        noisy_action = torch.randn(
            size=(batch_size, self.pred_horizon, self.action_dim),
            dtype=dtype, device=device)
        # Repeat the mask along the horizon so it lines up with the actions.
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        sampler = self.noise_scheduler_sample
        sampler.set_timesteps(self.num_inference_timesteps)

        for t in sampler.timesteps:
            # Build the state-action trajectory: adapted (action, mask)
            # tokens appended after the already-adapted state token.
            masked_action = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(masked_action)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Backbone prediction at the current timestep.
            model_output = self.model(state_action_traj, ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond, img_cond,
                                      lang_mask=lang_attn_mask)

            # One solver step, cast back to the working dtype.
            step_result = sampler.step(model_output, t, noisy_action)
            noisy_action = step_result.prev_sample.to(dtype)

        # Zero out the action dimensions the mask marks as invalid.
        return noisy_action * action_mask

    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens,
                     state_tokens, action_gt, action_mask, ctrl_freqs
                    ) -> torch.Tensor:
        '''
        Diffusion training loss for one batch.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        raises: ValueError if the configured prediction type is neither
            'epsilon' nor 'sample'
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Noise to be added to the ground-truth actions.
        noise = torch.randn(
            action_gt.shape, dtype=action_gt.dtype, device=device
        )
        # One random diffusion timestep per sample.
        timesteps = torch.randint(
            0, self.num_train_timesteps,
            (batch_size,), device=device
        ).long()
        # Forward diffusion: noise the actions according to each timestep.
        noisy_action = self.noise_scheduler.add_noise(
            action_gt, noise, timesteps)

        # Prepend the state token, then append the mask (repeated over
        # every token) along the feature dimension.
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)

        # Project all modalities to the backbone's hidden size.
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_action_traj)

        pred = self.model(state_action_traj, ctrl_freqs,
                          timesteps, lang_cond, img_cond,
                          lang_mask=lang_attn_mask)

        # The regression target depends on what the model is set to predict.
        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            return F.mse_loss(pred, noise)
        if pred_type == 'sample':
            return F.mse_loss(pred, action_gt)
        raise ValueError(f"Unsupported prediction type {pred_type}")

    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens,
                       action_mask, ctrl_freqs):
        '''
        Adapt the raw conditions and sample an action sequence.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # The state adaptor expects the mask concatenated to the state
        # along the feature dimension.
        masked_state = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(
            lang_tokens, img_tokens, masked_state)

        # Run the sampling loop on the adapted conditions.
        return self.conditional_sample(
            lang_cond, lang_attn_mask, img_cond,
            state_traj, action_mask, ctrl_freqs,
        )

    def forward(self, *args, **kwargs) -> torch.Tensor:
        '''Delegate to `compute_loss`, so calling the module returns the loss.'''
        return self.compute_loss(*args, **kwargs)
|
|