"""PyTorch Autoencoder model for Hugging Face Transformers."""

import random
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.utils import ModelOutput

from configuration_autoencoder import AutoencoderConfig

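# Minimal usage sketch (illustrative only; assumes AutoencoderConfig exposes the
# fields referenced below):
#
#   config = AutoencoderConfig(input_dim=32, hidden_dims=[64], latent_dim=8)
#   model = AutoencoderForReconstruction(config)
#   out = model(torch.randn(16, 32))
#   out.loss, out.reconstructed, out.last_hidden_state
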
class NeuralScaler(nn.Module):
    """Learnable alternative to StandardScaler using neural networks."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim

        # Small MLPs that predict data-dependent corrections to the batch statistics.
        self.mean_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        self.std_estimator = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Softplus()  # keeps the predicted std adjustment non-negative
        )

        # Learnable affine parameters, analogous to BatchNorm's weight and bias.
        self.weight = nn.Parameter(torch.ones(input_dim))
        self.bias = nn.Parameter(torch.zeros(input_dim))

        # Running statistics for inference, updated with exponential momentum.
        self.register_buffer('running_mean', torch.zeros(input_dim))
        self.register_buffer('running_std', torch.ones(input_dim))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

        self.momentum = 0.1

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the neural scaler.

        Args:
            x: Input tensor (2D or 3D)
            inverse: Whether to apply the inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss)
        """
        if inverse:
            return self._inverse_transform(x)

        # Flatten (batch, seq, dim) inputs so statistics are computed per feature.
        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))

        if self.training:
            batch_mean = x.mean(dim=0, keepdim=True)
            batch_std = x.std(dim=0, keepdim=True)

            # Learned corrections on top of the raw batch statistics.
            learned_mean_adj = self.mean_estimator(batch_mean)
            learned_std_adj = self.std_estimator(batch_std)

            effective_mean = batch_mean + learned_mean_adj
            effective_std = batch_std + learned_std_adj + 1e-8

            # Track running statistics (gradient-free, BatchNorm-style).
            with torch.no_grad():
                self.num_batches_tracked += 1
                if self.num_batches_tracked == 1:
                    self.running_mean.copy_(batch_mean.squeeze())
                    self.running_std.copy_(batch_std.squeeze())
                else:
                    self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
                    self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
        else:
            # At inference time, fall back to the tracked running statistics.
            effective_mean = self.running_mean.unsqueeze(0)
            effective_std = self.running_std.unsqueeze(0) + 1e-8

        normalized = (x - effective_mean) / effective_std

        # Learnable affine transform on top of the normalization.
        transformed = normalized * self.weight + self.bias

        if len(original_shape) == 3:
            transformed = transformed.view(original_shape)

        # Small penalty discouraging the per-feature affine parameters from drifting apart.
        reg_loss = 0.01 * (self.weight.var() + self.bias.var())

        return transformed, reg_loss

    def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the inverse transformation to recover the original scale."""
        if not self.config.learn_inverse_preprocessing:
            return x, torch.tensor(0.0, device=x.device)

        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))

        # Undo the affine transform, then the normalization. Only the running
        # statistics are inverted; the learned batch corrections are not
        # recoverable per sample, so this inverse is approximate.
        x = (x - self.bias) / (self.weight + 1e-8)

        effective_mean = self.running_mean.unsqueeze(0)
        effective_std = self.running_std.unsqueeze(0) + 1e-8
        x = x * effective_std + effective_mean

        if len(original_shape) == 3:
            x = x.view(original_shape)

        return x, torch.tensor(0.0, device=x.device)

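# Rough standalone usage of the scaler (illustrative; assumes the config sets
# learn_inverse_preprocessing=True so the inverse path is active):
#
#   scaler = NeuralScaler(config)
#   scaler.train()
#   y, reg_loss = scaler(torch.randn(32, config.input_dim))
#   scaler.eval()
#   x_approx, _ = scaler(y, inverse=True)   # approximate inverse via running stats
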
class CouplingLayer(nn.Module):
    """Affine coupling layer for normalizing flows."""

    def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Binary mask selecting which dimensions condition the transform.
        if mask_type == "alternating":
            self.register_buffer('mask', torch.arange(input_dim) % 2)
        elif mask_type == "half":
            mask = torch.zeros(input_dim)
            mask[:input_dim // 2] = 1
            self.register_buffer('mask', mask)
        else:
            raise ValueError(f"Unknown mask type: {mask_type}")

        masked_dim = int(self.mask.sum().item())
        unmasked_dim = input_dim - masked_dim

        # The masked dimensions parameterize scale and translation for the rest.
        self.scale_net = nn.Sequential(
            nn.Linear(masked_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, unmasked_dim),
            nn.Tanh()  # bounds the log-scales for numerical stability
        )

        self.translate_net = nn.Sequential(
            nn.Linear(masked_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, unmasked_dim)
        )

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the coupling layer.

        Args:
            x: Input tensor
            inverse: Whether to apply the inverse transformation

        Returns:
            Tuple of (transformed_tensor, log_determinant)
        """
        mask = self.mask.bool()
        x_masked = x[:, mask]
        x_unmasked = x[:, ~mask]

        s = self.scale_net(x_masked)
        t = self.translate_net(x_masked)

        if not inverse:
            # y = x * exp(s) + t, so log|det J| = sum(s).
            y_unmasked = x_unmasked * torch.exp(s) + t
            log_det = s.sum(dim=1)
        else:
            # Exact inverse: x = (y - t) * exp(-s).
            y_unmasked = (x_unmasked - t) * torch.exp(-s)
            log_det = -s.sum(dim=1)

        # Reassemble; masked dimensions pass through unchanged.
        y = torch.zeros_like(x)
        y[:, mask] = x_masked
        y[:, ~mask] = y_unmasked

        return y, log_det

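# Quick invertibility sanity check for the coupling layer (illustrative;
# dimensions are arbitrary):
#
#   layer = CouplingLayer(input_dim=6, hidden_dim=16)
#   x = torch.randn(4, 6)
#   y, log_det = layer(x)
#   x_rec, inv_log_det = layer(y, inverse=True)
#   assert torch.allclose(x, x_rec, atol=1e-5)
#   assert torch.allclose(log_det, -inv_log_det, atol=1e-5)
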
class NormalizingFlowPreprocessor(nn.Module):
    """Normalizing flow for learnable data preprocessing."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config
        input_dim = config.input_dim
        hidden_dim = config.preprocessing_hidden_dim
        num_layers = config.flow_coupling_layers

        # Alternate mask patterns so every dimension gets transformed somewhere.
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            mask_type = "alternating" if i % 2 == 0 else "half"
            self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))

        # Optional BatchNorm between coupling layers (none after the last one).
        if config.use_batch_norm:
            self.batch_norms = nn.ModuleList([
                nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)
            ])
        else:
            self.batch_norms = None

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the normalizing flow.

        Args:
            x: Input tensor (2D or 3D)
            inverse: Whether to apply the inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss)
        """
        original_shape = x.shape
        if x.dim() == 3:
            x = x.view(-1, x.size(-1))

        log_det_total = torch.zeros(x.size(0), device=x.device)

        if not inverse:
            for i, layer in enumerate(self.layers):
                x, log_det = layer(x, inverse=False)
                log_det_total += log_det

                if self.batch_norms is not None and i < len(self.layers) - 1:
                    x = self.batch_norms[i](x)
        else:
            for i, layer in enumerate(reversed(self.layers)):
                # Note: BatchNorm is applied in its forward direction here rather
                # than analytically inverted, so the inverse pass is approximate
                # and its log-determinant contribution is not tracked.
                if self.batch_norms is not None and i > 0:
                    bn_idx = len(self.layers) - 1 - i
                    x = self.batch_norms[bn_idx](x)

                x, log_det = layer(x, inverse=True)
                log_det_total += log_det

        if len(original_shape) == 3:
            x = x.view(original_shape)

        # Penalize large volume changes to keep the flow well-conditioned.
        reg_loss = 0.01 * log_det_total.abs().mean()

        return x, reg_loss

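# Background note: under the change-of-variables formula, a flow f with
# log|det J_f(x)| = log_det gives log p_x(x) = log p_z(f(x)) + log_det. Here the
# flow serves only as a learnable preprocessor, so log_det feeds a magnitude
# regularizer instead of a likelihood.
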
class LearnablePreprocessor(nn.Module):
    """Unified interface for learnable preprocessing methods."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        if not config.has_preprocessing:
            self.preprocessor = nn.Identity()
        elif config.is_neural_scaler:
            self.preprocessor = NeuralScaler(config)
        elif config.is_normalizing_flow:
            self.preprocessor = NormalizingFlowPreprocessor(config)
        else:
            raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")

    def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the preprocessing transformation.

        Args:
            x: Input tensor
            inverse: Whether to apply the inverse transformation

        Returns:
            Tuple of (transformed_tensor, regularization_loss)
        """
        if isinstance(self.preprocessor, nn.Identity):
            return x, torch.tensor(0.0, device=x.device)

        return self.preprocessor(x, inverse=inverse)

@dataclass
class AutoencoderOutput(ModelOutput):
    """
    Output type of AutoencoderModel.

    Args:
        last_hidden_state (torch.FloatTensor): The latent representation of the input.
        reconstructed (torch.FloatTensor, optional): The reconstructed input.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        attentions (tuple(torch.FloatTensor), optional): Not used in the basic autoencoder.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    last_hidden_state: torch.FloatTensor = None
    reconstructed: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None


@dataclass
class AutoencoderForReconstructionOutput(ModelOutput):
    """
    Output type of AutoencoderForReconstruction.

    Args:
        loss (torch.FloatTensor, optional): The reconstruction loss.
        reconstructed (torch.FloatTensor): The reconstructed input.
        last_hidden_state (torch.FloatTensor): The latent representation.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstructed: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None

class AutoencoderEncoder(nn.Module):
    """Encoder part of the autoencoder."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Stack of Linear -> (BatchNorm) -> activation -> (Dropout) blocks.
        layers = []
        input_dim = config.input_dim

        for hidden_dim in config.hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))

            if config.use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))

            layers.append(self._get_activation(config.activation))

            if config.dropout_rate > 0:
                layers.append(nn.Dropout(config.dropout_rate))

            input_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

        # Variational encoders emit distribution parameters instead of a point estimate.
        if config.is_variational:
            self.fc_mu = nn.Linear(input_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(input_dim, config.latent_dim)
        else:
            self.fc_out = nn.Linear(input_dim, config.latent_dim)

    def _get_activation(self, activation: str) -> nn.Module:
        """Get an activation function by name."""
        activations = {
            "relu": nn.ReLU(),
            "tanh": nn.Tanh(),
            "sigmoid": nn.Sigmoid(),
            "leaky_relu": nn.LeakyReLU(),
            "gelu": nn.GELU(),
            "swish": nn.SiLU(),
            "silu": nn.SiLU(),
            "elu": nn.ELU(),
            "prelu": nn.PReLU(),
            "relu6": nn.ReLU6(),
            "hardtanh": nn.Hardtanh(),
            "hardsigmoid": nn.Hardsigmoid(),
            "hardswish": nn.Hardswish(),
            "mish": nn.Mish(),
            "softplus": nn.Softplus(),
            "softsign": nn.Softsign(),
            "tanhshrink": nn.Tanhshrink(),
            "threshold": nn.Threshold(threshold=0.1, value=0),
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        return activations[activation]

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through the encoder."""
        # Denoising autoencoder: corrupt the input during training only.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        encoded = self.encoder(x)

        if self.config.is_variational:
            mu = self.fc_mu(encoded)
            logvar = self.fc_logvar(encoded)

            # Reparameterization trick: z = mu + eps * std keeps sampling differentiable.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            latent = self.fc_out(encoded)

            # Sparse autoencoder: ReLU zeroes out units, encouraging sparse codes.
            if self.config.is_sparse and self.training:
                latent = F.relu(latent)

            return latent

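# Shape sketch for the encoder (illustrative; assumes is_variational=True and
# latent_dim=16):
#
#   enc = AutoencoderEncoder(config)
#   z, mu, logvar = enc(torch.randn(8, config.input_dim))
#   # z, mu, logvar: (8, 16); at eval time z == mu (no sampling noise)
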
class AutoencoderDecoder(nn.Module):
    """Decoder part of the autoencoder."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Mirror of the encoder, ending at the original input dimension.
        layers = []
        input_dim = config.latent_dim
        decoder_dims = config.decoder_dims + [config.input_dim]

        for i, hidden_dim in enumerate(decoder_dims):
            layers.append(nn.Linear(input_dim, hidden_dim))

            # All but the final layer get normalization, activation, and dropout.
            if i < len(decoder_dims) - 1:
                if config.use_batch_norm:
                    layers.append(nn.BatchNorm1d(hidden_dim))

                layers.append(self._get_activation(config.activation))

                if config.dropout_rate > 0:
                    layers.append(nn.Dropout(config.dropout_rate))
            else:
                # Squash outputs to (0, 1) when training with BCE.
                if config.reconstruction_loss == "bce":
                    layers.append(nn.Sigmoid())

            input_dim = hidden_dim

        self.decoder = nn.Sequential(*layers)

    def _get_activation(self, activation: str) -> nn.Module:
        """Get an activation function by name."""
        activations = {
            "relu": nn.ReLU(),
            "tanh": nn.Tanh(),
            "sigmoid": nn.Sigmoid(),
            "leaky_relu": nn.LeakyReLU(),
            "gelu": nn.GELU(),
            "swish": nn.SiLU(),
            "silu": nn.SiLU(),
            "elu": nn.ELU(),
            "prelu": nn.PReLU(),
            "relu6": nn.ReLU6(),
            "hardtanh": nn.Hardtanh(),
            "hardsigmoid": nn.Hardsigmoid(),
            "hardswish": nn.Hardswish(),
            "mish": nn.Mish(),
            "softplus": nn.Softplus(),
            "softsign": nn.Softsign(),
            "tanhshrink": nn.Tanhshrink(),
            "threshold": nn.Threshold(threshold=0.1, value=0),
        }
        if activation not in activations:
            raise ValueError(f"Unknown activation: {activation}")
        return activations[activation]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the decoder."""
        return self.decoder(x)

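# Note: when reconstruction_loss == "bce", the final Sigmoid above means this
# decoder emits probabilities, so the loss head pairs it with plain BCE rather
# than the with-logits variant (see AutoencoderForReconstruction below).
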
class RecurrentEncoder(nn.Module):
    """Recurrent encoder for sequence data."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        self.rnn = rnn_class(
            input_size=config.input_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=config.bidirectional
        )

        # Project concatenated forward/backward states back to latent_dim.
        if config.bidirectional:
            self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
        else:
            self.projection = None

        # Variational heads map the final hidden state to mu and logvar so the
        # latent dimensionality stays consistent with the decoder.
        if config.is_variational:
            self.fc_mu = nn.Linear(config.latent_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(config.latent_dim, config.latent_dim)

        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the recurrent encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            lengths: Sequence lengths for packed sequences (optional)

        Returns:
            Encoded representation, or a (z, mu, logvar) tuple for VAEs
        """
        batch_size, seq_len, _ = x.shape

        # Denoising: corrupt the input during training only.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        # Pack variable-length sequences so padding doesn't pollute the hidden
        # state (pack_padded_sequence expects lengths on the CPU).
        if lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)

        if self.config.rnn_type == "lstm":
            output, (hidden, cell) = self.rnn(x)
        else:
            output, hidden = self.rnn(x)
            cell = None

        if lengths is not None:
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        if self.config.bidirectional:
            # Concatenate the last layer's forward and backward hidden states.
            hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
            hidden = hidden[-1]
            hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)

            if self.projection is not None:
                hidden = self.projection(hidden)
        else:
            hidden = hidden[-1]

        if self.batch_norm is not None:
            hidden = self.batch_norm(hidden)

        if self.dropout is not None and self.training:
            hidden = self.dropout(hidden)

        if self.config.is_variational:
            mu = self.fc_mu(hidden)
            logvar = self.fc_logvar(hidden)

            # Reparameterization trick, as in the feedforward encoder.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            return hidden

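# Packed-sequence usage sketch (illustrative; non-variational config): pass the
# true lengths for padded batches so the RNN ignores the padding:
#
#   enc = RecurrentEncoder(config)
#   x = torch.randn(4, 10, config.input_dim)   # padded to seq_len 10
#   lengths = torch.tensor([10, 7, 5, 3])
#   h = enc(x, lengths)                        # (4, latent_dim)
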
class RecurrentDecoder(nn.Module):
    """Recurrent decoder for sequence data."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # The decoder is always unidirectional; the latent state seeds generation.
        self.rnn = rnn_class(
            input_size=config.latent_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=False
        )

        # Maps each hidden state back to the input feature space.
        self.output_projection = nn.Linear(config.latent_dim, config.input_dim)

        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass through the recurrent decoder.

        Args:
            z: Latent representation of shape (batch_size, latent_dim)
            target_length: Length of the sequence to generate
            target_sequence: Target sequence for teacher forcing (optional)

        Returns:
            Decoded sequence of shape (batch_size, seq_len, input_dim)
        """
        batch_size = z.size(0)
        device = z.device

        # Use the latent vector as the initial hidden state of every layer.
        if self.config.rnn_type == "lstm":
            h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
            c_0 = torch.zeros_like(h_0)
            hidden = (h_0, c_0)
        else:
            hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)

        outputs = []

        # Start decoding from a zero input; the latent state carries the signal.
        current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        for t in range(target_length):
            use_teacher_forcing = (target_sequence is not None and
                                   self.training and
                                   random.random() < self.config.teacher_forcing_ratio)

            if use_teacher_forcing and t > 0:
                current_input = target_sequence[:, t - 1:t, :]

                # Teacher forcing only takes effect when input_dim == latent_dim,
                # since the RNN expects latent-sized inputs; otherwise fall back.
                if current_input.size(-1) != self.config.latent_dim:
                    current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

            output, hidden = self.rnn(current_input, hidden)

            output_flat = output.squeeze(1)

            if self.batch_norm is not None:
                output_flat = self.batch_norm(output_flat)

            if self.dropout is not None and self.training:
                output_flat = self.dropout(output_flat)

            step_output = self.output_projection(output_flat)
            outputs.append(step_output.unsqueeze(1))

            # Without teacher forcing, keep feeding zeros: the projected outputs
            # live in input space and cannot be fed back into the latent-sized RNN.
            if not use_teacher_forcing:
                current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        return torch.cat(outputs, dim=1)

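# Teacher forcing trades faster convergence for exposure bias: with probability
# teacher_forcing_ratio the decoder conditions on the ground-truth previous step
# instead of its own prediction. In this implementation it only takes effect
# when input_dim == latent_dim (see the dimension check above).
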
class AutoencoderModel(PreTrainedModel):
    """
    The bare Autoencoder Model outputting raw hidden states without any specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
    PyTorch documentation for all matters related to general usage and behavior.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"
    supports_gradient_checkpointing = False

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        # Optional learnable preprocessing applied before encoding.
        if config.has_preprocessing:
            self.preprocessor = LearnablePreprocessor(config)
        else:
            self.preprocessor = None

        # Pick the architecture family based on the config.
        if config.is_recurrent:
            self.encoder = RecurrentEncoder(config)
            self.decoder = RecurrentDecoder(config)
        else:
            self.encoder = AutoencoderEncoder(config)
            self.decoder = AutoencoderDecoder(config)

        if config.tie_weights:
            self._tie_weights()

        # Initialize weights and apply final processing.
        self.post_init()

    def _tie_weights(self):
        """Tie encoder and decoder weights (transpose relationship). Currently a no-op."""
        pass

    def get_input_embeddings(self):
        """Get input embeddings (not applicable for a basic autoencoder)."""
        return None

    def set_input_embeddings(self, value):
        """Set input embeddings (not applicable for a basic autoencoder)."""
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
        """
        Forward pass through the autoencoder.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on the autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            sequence_lengths (torch.Tensor, optional): Sequence lengths for the recurrent AE.
            target_length (int, optional): Target sequence length for the recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderOutput or tuple: The model outputs.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Learnable preprocessing (identity if disabled).
        preprocessing_loss = torch.tensor(0.0, device=input_values.device)
        if self.preprocessor is not None:
            input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False)

        if self.config.is_recurrent:
            if sequence_lengths is not None:
                encoder_output = self.encoder(input_values, sequence_lengths)
            else:
                encoder_output = self.encoder(input_values)

            # Cache mu/logvar so the reconstruction head can compute the KL term.
            if self.config.is_variational:
                latent, mu, logvar = encoder_output
                self._mu = mu
                self._logvar = logvar
            else:
                latent = encoder_output
                self._mu = None
                self._logvar = None

            # Decide how long a sequence to decode.
            if target_length is None:
                if self.config.sequence_length is not None:
                    target_length = self.config.sequence_length
                else:
                    target_length = input_values.size(1)

            # Teacher forcing uses the (preprocessed) input during training only.
            reconstructed = self.decoder(latent, target_length, input_values if self.training else None)
        else:
            encoder_output = self.encoder(input_values)

            if self.config.is_variational:
                latent, mu, logvar = encoder_output
                self._mu = mu
                self._logvar = logvar
            else:
                latent = encoder_output
                self._mu = None
                self._logvar = None

            reconstructed = self.decoder(latent)

        # Map reconstructions back to the original data scale if configured.
        if self.preprocessor is not None and self.config.learn_inverse_preprocessing:
            reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True)
            preprocessing_loss += inverse_loss

        hidden_states = None
        if output_hidden_states:
            if self.config.is_variational:
                hidden_states = (latent, mu, logvar)
            else:
                hidden_states = (latent,)

        if not return_dict:
            return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)

        return AutoencoderOutput(
            last_hidden_state=latent,
            reconstructed=reconstructed,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )

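# End-to-end sketch for the bare model (illustrative; feedforward config):
#
#   model = AutoencoderModel(config)
#   out = model(torch.randn(4, config.input_dim))
#   out.last_hidden_state   # latent codes, shape (4, latent_dim)
#   out.reconstructed       # reconstructions, shape (4, input_dim)
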
class AutoencoderForReconstruction(PreTrainedModel):
    """
    Autoencoder Model with a reconstruction head on top for reconstruction tasks.

    This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        self.autoencoder = AutoencoderModel(config)

        # Initialize weights and apply final processing.
        self.post_init()

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.autoencoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set input embeddings."""
        self.autoencoder.set_input_embeddings(value)

    def _compute_reconstruction_loss(
        self,
        reconstructed: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Compute the reconstruction loss based on the configured loss type."""
        if self.config.reconstruction_loss == "mse":
            return F.mse_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "bce":
            # The feedforward decoder already ends in Sigmoid when "bce" is
            # configured, so use plain BCE on probabilities there; the recurrent
            # decoder emits unbounded values, so use the logits variant.
            if self.config.is_recurrent:
                return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean")
            return F.binary_cross_entropy(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "l1":
            return F.l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "huber":
            return F.huber_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "smooth_l1":
            return F.smooth_l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "kl_div":
            # "batchmean" matches the mathematical definition of KL divergence.
            return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="batchmean")
        elif self.config.reconstruction_loss == "cosine":
            return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()
        elif self.config.reconstruction_loss == "focal":
            return self._focal_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "dice":
            return self._dice_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "tversky":
            return self._tversky_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "ssim":
            return self._ssim_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "perceptual":
            return self._perceptual_loss(reconstructed, target)
        else:
            raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")

    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
        """Compute a focal loss (MSE-based variant) that down-weights easy examples."""
        ce_loss = F.mse_loss(pred, target, reduction="none")
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

    def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Dice loss for segmentation-like tasks."""
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        return 1 - dice

    def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Tversky loss, a generalization of Dice loss."""
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        true_pos = (pred_flat * target_flat).sum()
        false_neg = (target_flat * (1 - pred_flat)).sum()
        false_pos = ((1 - target_flat) * pred_flat).sum()
        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        return 1 - tversky

    def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute an SSIM-based loss (simplified, per-feature-vector version)."""
        mu1 = pred.mean(dim=-1, keepdim=True)
        mu2 = target.mean(dim=-1, keepdim=True)
        sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
        sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
        sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)

        # Small stabilizing constants, as in the standard SSIM formulation.
        c1, c2 = 0.01, 0.03
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
        return 1 - ssim.mean()

    def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute a perceptual-style loss (simplified, using normalized feature differences)."""
        pred_norm = F.normalize(pred, p=2, dim=-1)
        target_norm = F.normalize(target, p=2, dim=-1)
        return F.mse_loss(pred_norm, target_norm)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
        """
        Forward pass with reconstruction loss calculation.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on the autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
            sequence_lengths (torch.Tensor, optional): Sequence lengths for the recurrent AE.
            target_length (int, optional): Target sequence length for the recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Plain autoencoding: reconstruct the input itself.
        if labels is None:
            labels = input_values

        outputs = self.autoencoder(
            input_values=input_values,
            sequence_lengths=sequence_lengths,
            target_length=target_length,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        reconstructed = outputs.reconstructed
        latent = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        recon_loss = self._compute_reconstruction_loss(reconstructed, labels)

        total_loss = recon_loss

        # Regularization from learnable preprocessing, if any.
        if outputs.preprocessing_loss is not None:
            total_loss = total_loss + outputs.preprocessing_loss

        if self.config.is_variational and getattr(self.autoencoder, '_mu', None) is not None:
            # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
            # averaged over batch and latent dimensions and weighted by beta.
            kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp())
            kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1))
            total_loss = total_loss + self.config.beta * kl_loss

        elif self.config.is_sparse:
            # L1 penalty on the latent code encourages sparsity.
            sparsity_loss = torch.mean(torch.abs(latent))
            total_loss = total_loss + 0.1 * sparsity_loss

        elif self.config.is_contractive:
            # Contractive penalty: squared norm of the input-gradient of the
            # latent, computed via a single autograd.grad call. This requires the
            # caller to pass input_values with requires_grad=True; otherwise the
            # penalty is skipped.
            if input_values.requires_grad:
                grads = torch.autograd.grad(
                    latent.sum(), input_values, create_graph=True, retain_graph=True
                )[0]
                contractive_loss = torch.sum(grads ** 2)
                total_loss = total_loss + 0.1 * contractive_loss

        loss = total_loss

        if not return_dict:
            output = (reconstructed, latent)
            if hidden_states is not None:
                output = output + (hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return AutoencoderForReconstructionOutput(
            loss=loss,
            reconstructed=reconstructed,
            last_hidden_state=latent,
            hidden_states=hidden_states,
            preprocessing_loss=outputs.preprocessing_loss,
        )
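
# Minimal training-loop sketch (illustrative; any optimizer works):
#
#   model = AutoencoderForReconstruction(config)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   for batch in dataloader:            # batch: (batch_size, input_dim)
#       out = model(batch)
#       optimizer.zero_grad()
#       out.loss.backward()
#       optimizer.step()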