# prosoro-mvae / modeling.py
# Author: han-xudong (commit 0cabb09)
import os
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
class MVAEConfig(PretrainedConfig):
    """Configuration for the multimodal VAE (MVAE) model.

    Stores the per-modality layer sizes and latent dimensionality used by
    :class:`MVAE`. Despite the ``*_dict`` names, the dimension containers are
    indexed by integer modality position in the model — presumably lists;
    verify against callers.
    """

    model_type = "mvae"

    def __init__(
        self,
        prosoro_type="cylinder",
        x_dim_dict=None,
        h1_dim_dict=None,
        h2_dim_dict=None,
        z_dim=32,
        layer_norm=False,
        use_activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Which prosoro sensor variant this config targets.
        self.prosoro_type = prosoro_type
        # Per-modality input and hidden layer sizes.
        self.x_dim_dict = x_dim_dict
        self.h1_dim_dict = h1_dim_dict
        self.h2_dim_dict = h2_dim_dict
        # Shared latent dimensionality.
        self.z_dim = z_dim
        # Architecture options (not consumed by the model code visible here).
        self.layer_norm = layer_norm
        self.use_activation = use_activation
class MVAE(PreTrainedModel):
    """Multimodal VAE with one encoder/decoder pair per modality.

    All modalities share a single latent space of size ``config.z_dim``.
    Modality ``i`` gets an MLP encoder (x -> h1 -> h2), mean/log-variance
    heads (h2 -> z), and an MLP decoder (z -> h2 -> h1 -> x).
    """

    config_class = MVAEConfig

    def __init__(self, config: MVAEConfig):
        super().__init__(config)
        self.prosoro_type = getattr(config, "prosoro_type", "cylinder")
        # NOTE(review): despite the "_dict" names on the config, these are
        # indexed by integer position below, i.e. used as sequences.
        self.x_dim_list = config.x_dim_dict
        self.h1_dim_list = config.h1_dim_dict
        self.h2_dim_list = config.h2_dim_dict
        self.z_dim = config.z_dim
        # Fix: the original constructed every module twice (the second
        # assignment replaced the first), doubling allocation and parameter
        # initialization for no effect. Each module is now built once.
        self.model = nn.ModuleDict()
        for i in range(len(self.x_dim_list)):
            # Encoder for modality i: x -> h1 -> h2 features.
            self.model[f"encoder_{i}"] = nn.Sequential(
                nn.Linear(self.x_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.h2_dim_list[i]),
            )
            # Decoder for modality i: z -> h2 -> h1 -> reconstructed x.
            self.model[f"decoder_{i}"] = nn.Sequential(
                nn.Linear(self.z_dim, self.h2_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h2_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.x_dim_list[i]),
            )
            # Heads projecting encoder features to latent mean / log-variance.
            self.model[f"fc_mu_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)
            self.model[f"fc_var_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)

    def sample(self, mu, var):
        """Reparameterized sampling from N(mu, exp(var)).

        ``var`` is treated as a log-variance (std = exp(var / 2)).
        Returns (prior N(0, 1), posterior q, differentiable sample z).
        """
        std = torch.exp(var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def x_to_z_encoder(self, x, input_index):
        """Encode modality ``input_index`` input ``x`` into a latent sample z."""
        # Fix: the original ran the encoder and both heads twice back-to-back,
        # discarding the first (identical) results — a pure waste of compute.
        h = self.model[f"encoder_{input_index}"](x)
        mu = self.model[f"fc_mu_{input_index}"](h)
        var = self.model[f"fc_var_{input_index}"](h)
        _, _, z = self.sample(mu, var)
        return z

    def z_to_x_decoder(self, z, output_index):
        """Decode latent ``z`` into a reconstruction for modality ``output_index``."""
        # Fix: the original invoked the decoder twice and kept only the second
        # (identical) result.
        return self.model[f"decoder_{output_index}"](z)

    def forward(self, x):
        """Encode ``x`` with encoder 0, decode to every modality.

        NOTE(review): the encoder index is fixed at 0, so ``x`` must be the
        first modality's input; a fresh z is sampled per decoded modality,
        matching the original behavior.
        """
        x_hat_list = []
        for i in range(len(self.x_dim_list)):
            z = self.x_to_z_encoder(x, 0)
            x_hat = self.z_to_x_decoder(z, i)
            x_hat_list.append(x_hat)
        return x_hat_list[1:]  # Return only force and shape

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load weights from the ``prosoro_type`` subdirectory of the given path."""
        config = kwargs.get("config", None)
        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        prosoro_type = getattr(config, "prosoro_type", None)
        # Fix: the original appended f"/{prosoro_type}" unconditionally,
        # yielding a bogus ".../None" path when the config lacks the field.
        if prosoro_type is not None:
            pretrained_model_name_or_path = f"{pretrained_model_name_or_path}/{prosoro_type}"
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)