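# Provenance (Hugging Face upload metadata): uploaded by euijinrnd with the
# upload-large-folder tool ("Add files using upload-large-folder tool"),
# commit 9de9fbf (verified).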
import re
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.schedulers.scheduling_dpmsolver_multistep import \
DPMSolverMultistepScheduler
from models.hub_mixin import CompatiblePyTorchModelHubMixin
from models.rdt.model import RDT
class RDTRunner(
    nn.Module,
    CompatiblePyTorchModelHubMixin,
    repo_url="https://huggingface.co/robotics-diffusion-transformer/rdt-1b"
):
    '''
    Diffusion-policy wrapper around the RDT denoising transformer.

    Owns the RDT network, three condition adaptors (language, image,
    state/action) that project inputs to the model's hidden size, a
    DDPMScheduler used to noise actions during training, and a
    DPMSolverMultistepScheduler used to denoise actions at inference.

    NOTE(review): the same state adaptor (input width
    ``state_token_dim * 2``) is applied both to ``cat([state, action_mask])``
    and to ``cat([noisy_action, action_mask])``. This class therefore appears
    to require ``action_dim == state_token_dim``; confirm with callers.
    '''

    def __init__(self, *, action_dim, pred_horizon, config,
                 lang_token_dim, img_token_dim, state_token_dim,
                 max_lang_cond_len, img_cond_len, lang_pos_embed_config=None,
                 img_pos_embed_config=None, dtype=torch.bfloat16):
        '''
        Build the denoising model, the condition adaptors and the schedulers.

        action_dim: width of each predicted action vector (RDT output dim).
        pred_horizon: number of action steps predicted per sample.
        config: nested dict. The following keys are read:
            config['rdt']['hidden_size' | 'depth' | 'num_heads'];
            config['lang_adaptor' | 'img_adaptor' | 'state_adaptor']
                (projector type strings, see build_condition_adapter);
            config['noise_scheduler']['num_train_timesteps' |
                'beta_schedule' | 'prediction_type' | 'clip_sample' |
                'num_inference_timesteps'].
        lang_token_dim, img_token_dim, state_token_dim: input widths of the
            language, image and state tokens.
        max_lang_cond_len, img_cond_len, lang_pos_embed_config,
        img_pos_embed_config, dtype: forwarded unchanged to RDT.
        '''
        super().__init__()

        # Create diffusion model
        hidden_size = config['rdt']['hidden_size']
        self.model = RDT(
            output_dim=action_dim,
            horizon=pred_horizon,
            hidden_size=hidden_size,
            depth=config['rdt']['depth'],
            num_heads=config['rdt']['num_heads'],
            max_lang_cond_len=max_lang_cond_len,
            img_cond_len=img_cond_len,
            lang_pos_embed_config=lang_pos_embed_config,
            img_pos_embed_config=img_pos_embed_config,
            dtype=dtype,
        )

        # Create adaptors for the various conditional inputs
        self.lang_adaptor = self.build_condition_adapter(
            config['lang_adaptor'],
            in_features=lang_token_dim,
            out_features=hidden_size
        )
        self.img_adaptor = self.build_condition_adapter(
            config['img_adaptor'],
            in_features=img_token_dim,
            out_features=hidden_size
        )
        # A `state` refers to an action or a proprioception vector
        self.state_adaptor = self.build_condition_adapter(
            config['state_adaptor'],
            in_features=state_token_dim * 2,  # state + state mask (indicator)
            out_features=hidden_size
        )

        # Create the noise schedulers: DDPM for training-time noising,
        # DPM-Solver (multistep) for inference-time sampling.
        noise_scheduler_config = config['noise_scheduler']
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
            clip_sample=noise_scheduler_config['clip_sample'],
        )
        self.noise_scheduler_sample = DPMSolverMultistepScheduler(
            num_train_timesteps=noise_scheduler_config['num_train_timesteps'],
            beta_schedule=noise_scheduler_config['beta_schedule'],
            prediction_type=noise_scheduler_config['prediction_type'],
        )

        self.num_train_timesteps = noise_scheduler_config['num_train_timesteps']
        self.num_inference_timesteps = noise_scheduler_config['num_inference_timesteps']
        self.prediction_type = noise_scheduler_config['prediction_type']

        self.pred_horizon = pred_horizon
        self.action_dim = action_dim

        # Report the trainable size of the diffusion stack (model + adaptors).
        num_params = sum(
            p.numel()
            for module in (self.model, self.lang_adaptor,
                           self.img_adaptor, self.state_adaptor)
            for p in module.parameters()
        )
        print("Diffusion params: %e" % num_params)

    def build_condition_adapter(
            self, projector_type, in_features, out_features):
        '''
        Create a projector mapping ``in_features`` to ``out_features``.

        projector_type: either 'linear' (a single nn.Linear) or
            'mlp{N}x_gelu' (N nn.Linear layers, each pair separated by a
            tanh-approximated nn.GELU).
        return: the projector nn.Module.
        raises: ValueError if ``projector_type`` matches neither form.
        '''
        if projector_type == 'linear':
            return nn.Linear(in_features, out_features)

        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(in_features, out_features)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU(approximate="tanh"))
                modules.append(nn.Linear(out_features, out_features))
            return nn.Sequential(*modules)

        raise ValueError(f'Unknown projector type: {projector_type}')

    def adapt_conditions(self, lang_tokens, img_tokens, state_tokens):
        '''
        Project all conditioning inputs to the model's hidden size.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, state_len, state_token_dim)
        return: adapted (..., hidden_size) tensors for the language, image
            and state tokens, in that order.
        '''
        adapted_lang = self.lang_adaptor(lang_tokens)
        adapted_img = self.img_adaptor(img_tokens)
        adapted_state = self.state_adaptor(state_tokens)

        return adapted_lang, adapted_img, adapted_state

    def conditional_sample(self, lang_cond, lang_attn_mask, img_cond,
                           state_traj, action_mask, ctrl_freqs):
        '''
        Denoise a random action trajectory with the DPM-Solver scheduler.

        lang_cond: language conditional data, (batch_size, lang_len, hidden_size).
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_cond: image conditional data, (batch_size, img_len, hidden_size).
        state_traj: (batch_size, 1, hidden_size), state trajectory.
        action_mask: (batch_size, 1, action_dim), a 0-1 **float** tensor
            indicating the valid action dimensions.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim)
        '''
        device = state_traj.device
        dtype = state_traj.dtype
        noisy_action = torch.randn(
            size=(state_traj.shape[0], self.pred_horizon, self.action_dim),
            dtype=dtype, device=device)
        action_mask = action_mask.expand(-1, self.pred_horizon, -1)

        # Set step values (this also resets the multistep scheduler's state)
        self.noise_scheduler_sample.set_timesteps(self.num_inference_timesteps)

        for t in self.noise_scheduler_sample.timesteps:
            # Prepare state-action trajectory
            action_traj = torch.cat([noisy_action, action_mask], dim=2)
            action_traj = self.state_adaptor(action_traj)
            state_action_traj = torch.cat([state_traj, action_traj], dim=1)

            # Predict the model output
            model_output = self.model(state_action_traj, ctrl_freqs,
                                      t.unsqueeze(-1).to(device),
                                      lang_cond, img_cond,
                                      lang_mask=lang_attn_mask)

            # Compute previous actions: x_t -> x_t-1
            noisy_action = self.noise_scheduler_sample.step(
                model_output, t, noisy_action).prev_sample
            # The scheduler may upcast; keep the trajectory dtype stable.
            noisy_action = noisy_action.to(dtype)

        # Finally apply the action mask to mask invalid action dimensions
        noisy_action = noisy_action * action_mask

        return noisy_action

    # ========= Train  ============
    def compute_loss(self, lang_tokens, lang_attn_mask, img_tokens,
                     state_tokens, action_gt, action_mask, ctrl_freqs
                     ) -> torch.Tensor:
        '''
        Compute the diffusion training loss (MSE against the scheduler target).

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_gt: (batch_size, horizon, state_token_dim), ground-truth actions for supervision
        action_mask: (batch_size, 1, state_token_dim), a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: loss_value, a scalar tensor
        raises: ValueError if the configured prediction type is not one of
            'epsilon', 'sample' or 'v_prediction'.
        '''
        batch_size = lang_tokens.shape[0]
        device = lang_tokens.device

        # Sample noise that we'll add to the actions
        noise = torch.randn(
            action_gt.shape, dtype=action_gt.dtype, device=device
        )
        # Sample random diffusion timesteps
        timesteps = torch.randint(
            0, self.num_train_timesteps,
            (batch_size,), device=device
        ).long()
        # Add noise to the clean actions according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_action = self.noise_scheduler.add_noise(
            action_gt, noise, timesteps)

        # Concatenate the state and action tokens to form the input sequence
        state_action_traj = torch.cat([state_tokens, noisy_action], dim=1)
        # Append the action mask to the input sequence
        action_mask = action_mask.expand(-1, state_action_traj.shape[1], -1)
        state_action_traj = torch.cat([state_action_traj, action_mask], dim=2)
        # Align the dimension with the hidden size
        lang_cond, img_cond, state_action_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_action_traj)
        # Predict the denoised result
        pred = self.model(state_action_traj, ctrl_freqs,
                          timesteps, lang_cond, img_cond,
                          lang_mask=lang_attn_mask)

        pred_type = self.prediction_type
        if pred_type == 'epsilon':
            target = noise
        elif pred_type == 'sample':
            target = action_gt
        elif pred_type == 'v_prediction':
            # Both schedulers built in __init__ accept 'v_prediction'; the
            # matching training target is the scheduler's velocity.
            target = self.noise_scheduler.get_velocity(
                action_gt, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction type {pred_type}")

        loss = F.mse_loss(pred, target)
        return loss

    # ========= Inference  ============
    def predict_action(self, lang_tokens, lang_attn_mask, img_tokens, state_tokens,
                       action_mask, ctrl_freqs):
        '''
        Predict an action sequence from raw (un-adapted) condition tokens.

        lang_tokens: (batch_size, lang_len, lang_token_dim)
        lang_attn_mask: (batch_size, lang_len), a mask for valid language tokens,
            which should be True-False bool tensor.
        img_tokens: (batch_size, img_len, img_token_dim)
        state_tokens: (batch_size, 1, state_token_dim)
        action_mask: (batch_size, 1, action_dim),
            which should be a 0-1 **float** tensor.
        ctrl_freqs: (batch_size,), control frequency for each sample.

        return: (batch_size, horizon, action_dim), predicted action sequence
        '''
        # Prepare the state and conditions
        state_tokens = torch.cat([state_tokens, action_mask], dim=2)
        lang_cond, img_cond, state_traj = self.adapt_conditions(
            lang_tokens, img_tokens, state_tokens)

        # Run sampling
        action_pred = self.conditional_sample(
            lang_cond, lang_attn_mask, img_cond,
            state_traj, action_mask, ctrl_freqs,
        )

        return action_pred

    def forward(self, *args, **kwargs) -> torch.Tensor:
        '''Training entry point: delegates to ``compute_loss``.'''
        return self.compute_loss(*args, **kwargs)