import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig


class MVAEConfig(PretrainedConfig):
    model_type = "mvae"

    def __init__(
        self,
        prosoro_type="cylinder",
        x_dim_dict=None,
        h1_dim_dict=None,
        h2_dim_dict=None,
        z_dim=32,
        layer_norm=False,
        use_activation="relu",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.prosoro_type = prosoro_type
        # Despite the *_dict names, these fields hold per-modality dimension lists.
        self.x_dim_dict = x_dim_dict
        self.h1_dim_dict = h1_dim_dict
        self.h2_dim_dict = h2_dim_dict
        self.z_dim = z_dim
        # Serialized with the config; not currently consumed by MVAE below.
        self.layer_norm = layer_norm
        self.use_activation = use_activation


class MVAE(PreTrainedModel):
    config_class = MVAEConfig

    def __init__(self, config: MVAEConfig):
        super().__init__(config)

        self.prosoro_type = getattr(config, "prosoro_type", "cylinder")
        self.x_dim_list = config.x_dim_dict
        self.h1_dim_list = config.h1_dim_dict
        self.h2_dim_list = config.h2_dim_dict
        self.z_dim = config.z_dim

        # One encoder/decoder pair plus mu/log-variance heads per modality.
        self.model = nn.ModuleDict()
        for i in range(len(self.x_dim_list)):
            # Encoder for modality i: x -> h1 -> h2.
            self.model[f"encoder_{i}"] = nn.Sequential(
                nn.Linear(self.x_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.h2_dim_list[i]),
            )
            # Decoder for modality i: z -> h2 -> h1 -> x.
            self.model[f"decoder_{i}"] = nn.Sequential(
                nn.Linear(self.z_dim, self.h2_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h2_dim_list[i], self.h1_dim_list[i]),
                nn.ReLU(),
                nn.Linear(self.h1_dim_list[i], self.x_dim_list[i]),
            )
            # Heads mapping h2 to the mean and log-variance of q(z | x).
            self.model[f"fc_mu_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)
            self.model[f"fc_var_{i}"] = nn.Linear(self.h2_dim_list[i], self.z_dim)

    def sample(self, mu, log_var):
        # Reparameterization trick: std = exp(log_var / 2), z = mu + std * eps.
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        return p, q, z

    def x_to_z_encoder(self, x, input_index):
        h = self.model[f"encoder_{input_index}"](x)
        mu = self.model[f"fc_mu_{input_index}"](h)
        log_var = self.model[f"fc_var_{input_index}"](h)
        _, _, z = self.sample(mu, log_var)
        return z

    def z_to_x_decoder(self, z, output_index):
        return self.model[f"decoder_{output_index}"](z)

    def forward(self, x):
        # x is a modality-0 observation. Encode it once per target modality
        # (each pass draws a fresh z), decode into every modality, then drop
        # the modality-0 self-reconstruction from the returned list.
        x_hat_list = []
        for i in range(len(self.x_dim_list)):
            z = self.x_to_z_encoder(x, 0)
            x_hat = self.z_to_x_decoder(z, i)
            x_hat_list.append(x_hat)
        return x_hat_list[1:]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        config = kwargs.get("config")

        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        # Weights are resolved from a subfolder named after the prosoro type,
        # e.g. "<repo_or_dir>/cylinder".
        prosoro_type = getattr(config, "prosoro_type", None)
        if prosoro_type is not None:
            pretrained_model_name_or_path = f"{pretrained_model_name_or_path}/{prosoro_type}"

        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
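

# Hedged example, not part of the model definition: `sample` returns both the
# prior p and the posterior q, which points at an ELBO-style objective. The
# `elbo_loss` helper below is an illustrative sketch; the MSE reconstruction
# term and the unit KL weight are assumptions, not this repository's training code.
def elbo_loss(model, x_in, x_target, input_index=0, output_index=1):
    h = model.model[f"encoder_{input_index}"](x_in)
    mu = model.model[f"fc_mu_{input_index}"](h)
    log_var = model.model[f"fc_var_{input_index}"](h)
    p, q, z = model.sample(mu, log_var)
    x_hat = model.z_to_x_decoder(z, output_index)
    recon = nn.functional.mse_loss(x_hat, x_target)      # reconstruction term
    kl = torch.distributions.kl_divergence(q, p).mean()  # KL(q(z|x) || N(0, I))
    return recon + kl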
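

# Minimal usage sketch; the dimensions below are illustrative placeholders,
# not values from any released checkpoint.
if __name__ == "__main__":
    config = MVAEConfig(
        x_dim_dict=[6, 3],  # per-modality input dims; modality 0 feeds the encoder
        h1_dim_dict=[64, 64],
        h2_dim_dict=[32, 32],
        z_dim=32,
    )
    model = MVAE(config)
    x = torch.randn(4, 6)  # batch of 4 modality-0 observations
    preds = model(x)       # cross-modal predictions for modalities 1..N-1
    print([tuple(p.shape) for p in preds])  # [(4, 3)]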