import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class MVAEConfig(PretrainedConfig):
    model_type = "mvae"

    def __init__(
        self,
        prosoro_type="cylinder",
        x_dim_dict=None,
        h1_dim_dict=None,
        h2_dim_dict=None,
        z_dim=32,
        layer_norm=False,
        use_activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.prosoro_type = prosoro_type
        self.x_dim_dict = x_dim_dict
        self.h1_dim_dict = h1_dim_dict
        self.h2_dim_dict = h2_dim_dict
        self.z_dim = z_dim
        self.layer_norm = layer_norm
        self.use_activation = use_activation


class MVAE(PreTrainedModel):
    config_class = MVAEConfig

    def __init__(self, config: MVAEConfig):
        super().__init__(config)
        self.prosoro_type = getattr(config, "prosoro_type", "cylinder")
        # Per-modality dimensions; despite the "_dict" config names, these are
        # ordered lists indexed by modality.
        self.x_dim_list = config.x_dim_dict
        self.h1_dim_list = config.h1_dim_dict
        self.h2_dim_list = config.h2_dim_dict
        self.z_dim = config.z_dim

        # One encoder, decoder, and pair of latent heads per modality.
        self.model = nn.ModuleDict()
        for i in range(len(self.x_dim_list)):
            # Encoder: input -> h1 -> h2
            self.model[f"encoder_{i}"] = nn.Sequential(
                nn.Linear(self.x_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.h2_dim_list[i]),
            )
            # Decoder: z -> h2 -> h1 -> reconstruction
            self.model[f"decoder_{i}"] = nn.Sequential(
                nn.Linear(self.z_dim, self.h2_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h2_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.x_dim_list[i]),
            )
            # Heads mapping the encoder output to the latent mean and log-variance.
            self.model[f"fc_mu_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)
            self.model[f"fc_var_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)

    def sample(self, mu, log_var):
        # Reparameterization trick: z = mu + std * eps with std = exp(log_var / 2),
        # so gradients flow through mu and log_var.
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def x_to_z_encoder(self, x, input_index):
        # Encode one modality into a latent sample.
        h = self.model[f"encoder_{input_index}"](x)
        mu = self.model[f"fc_mu_{input_index}"](h)
        log_var = self.model[f"fc_var_{input_index}"](h)
        _, _, z = self.sample(mu, log_var)
        return z

    def z_to_x_decoder(self, z, output_index):
        # Decode a latent sample into one modality's reconstruction.
        return self.model[f"decoder_{output_index}"](z)

    def forward(self, x):
        # Encode modality 0 once and decode the shared latent into every modality.
        z = self.x_to_z_encoder(x, 0)
        x_hat_list = [self.z_to_x_decoder(z, i) for i in range(len(self.x_dim_list))]
        return x_hat_list[1:]  # Return only the force and shape reconstructions

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        config = kwargs.get("config", None)
        if config is None:
            from transformers import AutoConfig

            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
        # Checkpoints live in a subdirectory named after the ProSoRo type.
        prosoro_type = getattr(config, "prosoro_type", None)
        pretrained_model_name_or_path = pretrained_model_name_or_path + f"/{prosoro_type}"
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
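

# Loading sketch: the from_pretrained override above appends config.prosoro_type
# to the given path, so weights are expected under e.g. <root>/cylinder. The path
# below is hypothetical, not from the source:
#
#   model = MVAE.from_pretrained("checkpoints/mvae")  # resolves to checkpoints/mvae/cylinder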
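

# Usage sketch: a minimal smoke test assuming three modalities, with force and
# shape as the last two (per the forward comment). All dimensions here are
# illustrative assumptions; the real sizes come from the project config.
if __name__ == "__main__":
    config = MVAEConfig(
        prosoro_type="cylinder",
        x_dim_dict=[6, 6, 300],     # assumed per-modality input sizes
        h1_dim_dict=[64, 64, 256],  # assumed first hidden sizes
        h2_dim_dict=[32, 32, 128],  # assumed second hidden sizes
        z_dim=32,
    )
    model = MVAE(config)
    x = torch.randn(8, config.x_dim_dict[0])  # batch of modality-0 inputs
    force_hat, shape_hat = model(x)
    print(force_hat.shape, shape_hat.shape)  # torch.Size([8, 6]) torch.Size([8, 300])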