""" PyTorch Autoencoder model for Hugging Face Transformers. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union, Dict, Any, List from dataclasses import dataclass import random import re # Import PreTrainedModel in a way that avoids circular imports in some environments (e.g., Databricks) try: from transformers.modeling_utils import PreTrainedModel except Exception: # Fallback if direct path is unavailable from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput from transformers.utils import ModelOutput try: from .configuration_autoencoder import AutoencoderConfig # when loaded via HF dynamic module except Exception: from configuration_autoencoder import AutoencoderConfig # local usage # Block-based architecture components try: from .blocks import ( BlockFactory, BlockSequence, LinearBlockConfig, AttentionBlockConfig, RecurrentBlockConfig, ConvolutionalBlockConfig, VariationalBlockConfig, VariationalBlock, ) # when in package except Exception: from blocks import ( BlockFactory, BlockSequence, LinearBlockConfig, AttentionBlockConfig, RecurrentBlockConfig, ConvolutionalBlockConfig, VariationalBlockConfig, VariationalBlock, ) # local usage # Shared utilities try: from .utils import _get_activation except Exception: from utils import _get_activation # Preprocessing components try: from .preprocessing import PreprocessingBlock # when in package except Exception: from preprocessing import PreprocessingBlock # local usage @dataclass class AutoencoderOutput(ModelOutput): """ Output type of AutoencoderModel. Args: last_hidden_state (torch.FloatTensor): The latent representation of the input. reconstructed (torch.FloatTensor, optional): The reconstructed input. hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder. preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. """ last_hidden_state: torch.FloatTensor = None reconstructed: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None preprocessing_loss: Optional[torch.FloatTensor] = None @dataclass class AutoencoderForReconstructionOutput(ModelOutput): """ Output type of AutoencoderForReconstruction. Args: loss (torch.FloatTensor, optional): The reconstruction loss. reconstructed (torch.FloatTensor): The reconstructed input. last_hidden_state (torch.FloatTensor): The latent representation. hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. 
""" loss: Optional[torch.FloatTensor] = None reconstructed: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None preprocessing_loss: Optional[torch.FloatTensor] = None class AutoencoderEncoder(nn.Module): """Encoder part of the autoencoder.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Build encoder layers layers = [] input_dim = config.input_dim for hidden_dim in config.hidden_dims: layers.append(nn.Linear(input_dim, hidden_dim)) if config.use_batch_norm: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(self._get_activation(config.activation)) if config.dropout_rate > 0: layers.append(nn.Dropout(config.dropout_rate)) input_dim = hidden_dim self.encoder = nn.Sequential(*layers) # For variational autoencoders, we need separate layers for mean and log variance if config.is_variational: self.fc_mu = nn.Linear(input_dim, config.latent_dim) self.fc_logvar = nn.Linear(input_dim, config.latent_dim) else: # Standard encoder output self.fc_out = nn.Linear(input_dim, config.latent_dim) def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Forward pass through encoder.""" # Add noise for denoising autoencoders if self.config.is_denoising and self.training: noise = torch.randn_like(x) * self.config.noise_factor x = x + noise encoded = self.encoder(x) if self.config.is_variational: # Variational autoencoder: return mean, log variance, and sampled latent mu = self.fc_mu(encoded) logvar = self.fc_logvar(encoded) # Reparameterization trick if self.training: std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std else: z = mu # Use mean during inference return z, mu, logvar else: # Standard autoencoder latent = self.fc_out(encoded) # Add sparsity constraint for sparse autoencoders if self.config.is_sparse and self.training: # Apply L1 regularization to encourage sparsity latent = F.relu(latent) # Ensure non-negative activations return latent class AutoencoderDecoder(nn.Module): """Decoder part of the autoencoder.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Build decoder layers (reverse of encoder) layers = [] input_dim = config.latent_dim decoder_dims = config.decoder_dims + [config.input_dim] for i, hidden_dim in enumerate(decoder_dims): layers.append(nn.Linear(input_dim, hidden_dim)) # Don't add batch norm, activation, or dropout to the final layer if i < len(decoder_dims) - 1: if config.use_batch_norm: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(_get_activation(config.activation)) if config.dropout_rate > 0: layers.append(nn.Dropout(config.dropout_rate)) else: # Final layer - add appropriate activation based on reconstruction loss if config.reconstruction_loss == "bce": layers.append(nn.Sigmoid()) input_dim = hidden_dim self.decoder = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through decoder.""" return self.decoder(x) class RecurrentEncoder(nn.Module): """Recurrent encoder for sequence data.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Get RNN class if config.rnn_type == "lstm": rnn_class = nn.LSTM elif config.rnn_type == "gru": rnn_class = nn.GRU elif config.rnn_type == "rnn": rnn_class = nn.RNN else: raise ValueError(f"Unknown RNN type: {config.rnn_type}") # Create RNN layers self.rnn = rnn_class( input_size=config.input_dim, 
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=config.bidirectional,
        )

        # Projection layer for bidirectional RNN
        if config.bidirectional:
            self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
        else:
            self.projection = None

        # Batch normalization
        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        # Dropout
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(
        self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Forward pass through recurrent encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            lengths: Sequence lengths for packed sequences (optional)

        Returns:
            Encoded representation or tuple for VAE
        """
        batch_size, seq_len, _ = x.shape

        # Add noise for denoising autoencoders
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        # Pack sequences if lengths provided (pack_padded_sequence expects lengths on CPU)
        if lengths is not None:
            x = nn.utils.rnn.pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)

        # RNN forward pass
        if self.config.rnn_type == "lstm":
            output, (hidden, cell) = self.rnn(x)
        else:
            output, hidden = self.rnn(x)
            cell = None

        # Unpack if necessary
        if lengths is not None:
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        # Use last hidden state as encoding
        if self.config.bidirectional:
            # Concatenate forward and backward hidden states
            hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
            hidden = hidden[-1]  # Take last layer
            hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)  # Concatenate directions

            # Project to latent dimension
            if self.projection:
                hidden = self.projection(hidden)
        else:
            hidden = hidden[-1]  # Take last layer

        # Apply batch normalization
        if self.batch_norm:
            hidden = self.batch_norm(hidden)

        # Apply dropout
        if self.dropout and self.training:
            hidden = self.dropout(hidden)

        # Handle variational encoding
        if self.config.is_variational:
            # Split hidden into mean and log variance (each half spans latent_dim // 2 features)
            mu = hidden[:, : self.config.latent_dim // 2]
            logvar = hidden[:, self.config.latent_dim // 2 :]

            # Reparameterization trick
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            return hidden


class RecurrentDecoder(nn.Module):
    """Recurrent decoder for sequence data."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Get RNN class
        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        # Create RNN layers
        self.rnn = rnn_class(
            input_size=config.latent_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=False,  # Decoder is always unidirectional
        )

        # Output projection
        self.output_projection = nn.Linear(config.latent_dim, config.input_dim)

        # Batch normalization
        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        # Dropout
        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(
        self,
        z: torch.Tensor,
        target_length: int,
        target_sequence: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through recurrent decoder.

        Args:
            z: Latent representation of shape (batch_size, latent_dim)
            target_length: Length of sequence to generate
            target_sequence: Target sequence for teacher forcing (optional)

        Returns:
            Decoded sequence of shape (batch_size, seq_len, input_dim)
        """
        batch_size = z.size(0)
        device = z.device

        # Initialize hidden state with latent representation
        if self.config.rnn_type == "lstm":
            h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
            c_0 = torch.zeros_like(h_0)
            hidden = (h_0, c_0)
        else:
            hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)

        outputs = []

        # Initialize input (can be learned or zero)
        current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        for t in range(target_length):
            # Teacher forcing decision
            use_teacher_forcing = (
                target_sequence is not None
                and self.training
                and random.random() < self.config.teacher_forcing_ratio
            )

            if use_teacher_forcing and t > 0:
                # Use previous target as input
                current_input = target_sequence[:, t - 1 : t, :]
                # Fall back to a zero vector if the target width does not match the latent dimension
                if current_input.size(-1) != self.config.latent_dim:
                    current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

            # RNN forward step
            output, hidden = self.rnn(current_input, hidden)

            # Apply batch normalization and dropout
            output_flat = output.squeeze(1)  # Remove sequence dimension
            if self.batch_norm:
                output_flat = self.batch_norm(output_flat)
            if self.dropout and self.training:
                output_flat = self.dropout(output_flat)

            # Project to output dimension
            step_output = self.output_projection(output_flat)
            outputs.append(step_output.unsqueeze(1))

            # Without teacher forcing, feed a zero vector as the next input
            if not use_teacher_forcing:
                current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        # Concatenate all outputs
        return torch.cat(outputs, dim=1)


class AutoencoderModel(PreTrainedModel):
    """
    The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings,
    pruning heads, etc.)

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
    PyTorch documentation for all matters related to general usage and behavior.
""" config_class = AutoencoderConfig base_model_prefix = "autoencoder" supports_gradient_checkpointing = False def __init__(self, config: AutoencoderConfig): super().__init__(config) self.config = config # Initialize learnable preprocessing as a single forward block only if config.has_preprocessing: self.pre_block = PreprocessingBlock(config, inverse=False) else: self.pre_block = None # Build block-based encoder/decoder sequences (breaking change refactor) norm = "batch" if config.use_batch_norm else "none" def default_linear_sequence(in_dim: int, dims: List[int], activation: str, normalization: str, dropout: float) -> List[LinearBlockConfig]: cfgs: List[LinearBlockConfig] = [] prev = in_dim for h in dims: cfgs.append( LinearBlockConfig( input_dim=prev, output_dim=h, activation=activation, normalization=normalization, dropout_rate=dropout, use_residual=False, ) ) prev = h return cfgs # Encoder: use explicit block list if provided, else hidden_dims default if getattr(config, "encoder_blocks", None): enc_cfgs = config.encoder_blocks # Compute enc_out_dim from last block's output_dim if linear/conv, else assume input_dim last_out = None for b in enc_cfgs: if isinstance(b, dict): last_out = b.get("output_dim", last_out) else: last_out = getattr(b, "output_dim", last_out) enc_out_dim = last_out or (config.hidden_dims[-1] if config.hidden_dims else config.input_dim) else: enc_cfgs = default_linear_sequence(config.input_dim, config.hidden_dims, config.activation, norm, config.dropout_rate) enc_out_dim = config.hidden_dims[-1] if config.hidden_dims else config.input_dim base_encoder_seq: BlockSequence = BlockFactory.build_sequence(enc_cfgs) if len(enc_cfgs) > 0 else BlockSequence([]) # Do not inject pre_block into encoder sequence; apply it explicitly in forward self.encoder_seq = base_encoder_seq # Project to latent if config.is_variational: self.fc_mu = nn.Linear(enc_out_dim, config.latent_dim) self.fc_logvar = nn.Linear(enc_out_dim, config.latent_dim) self.to_latent = None else: self.fc_mu = None self.fc_logvar = None self.to_latent = nn.Linear(enc_out_dim, config.latent_dim) # Decoder: use explicit block list if provided, else default MLP back to input if getattr(config, "decoder_blocks", None): dec_cfgs = config.decoder_blocks else: dec_dims = config.decoder_dims + [config.input_dim] dec_cfgs = default_linear_sequence(config.latent_dim, dec_dims, config.activation, norm, config.dropout_rate) # For final projection to input_dim: identity activation and no norm/dropout if len(dec_cfgs) > 0: last = dec_cfgs[-1] last.activation = "identity" last.normalization = "none" last.dropout_rate = 0.0 self.decoder_seq: BlockSequence = BlockFactory.build_sequence(dec_cfgs) if len(dec_cfgs) > 0 else BlockSequence([]) # Tie weights if specified (no-op for now) if config.tie_weights: self._tie_weights() # Initialize weights self.post_init() def _tie_weights(self): """Tie encoder and decoder weights (transpose relationship).""" # This is a simplified weight tying - in practice, you might want more sophisticated tying pass def get_input_embeddings(self): """Get input embeddings (not applicable for basic autoencoder).""" return None def set_input_embeddings(self, value): """Set input embeddings (not applicable for basic autoencoder).""" pass def forward( self, input_values: torch.Tensor, sequence_lengths: Optional[torch.Tensor] = None, target_length: Optional[int] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]: """ 
        Forward pass through the autoencoder.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderOutput or tuple: The model outputs.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Apply learnable preprocessing via block (forward only)
        if self.pre_block is not None:
            input_values = self.pre_block(input_values)
        preprocessing_loss = torch.tensor(0.0, device=input_values.device)

        # Block-based forward: encode through the block sequence
        enc_out = self.encoder_seq(input_values)

        # Sample or project to latent
        if self.config.is_variational:
            # Use VariationalBlock to encapsulate VAE behavior. Note that the block is created
            # lazily on the first forward pass, so its parameters are only registered from then on.
            self._variational = getattr(self, '_variational', None)
            if self._variational is None:
                self._variational = VariationalBlock(
                    VariationalBlockConfig(input_dim=enc_out.shape[-1], latent_dim=self.config.latent_dim)
                ).to(enc_out.device)
            latent = self._variational(enc_out, training=self.training)
            self._mu = self._variational._mu
            self._logvar = self._variational._logvar
        else:
            latent = self.to_latent(enc_out) if self.to_latent is not None else enc_out
            self._mu, self._logvar = None, None

        # Decode back to input space
        reconstructed = self.decoder_seq(latent)

        hidden_states = None
        if output_hidden_states:
            if self.config.is_variational:
                hidden_states = (latent, getattr(self, '_mu', None), getattr(self, '_logvar', None))
            else:
                hidden_states = (latent,)

        if not return_dict:
            return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)

        return AutoencoderOutput(
            last_hidden_state=latent,
            reconstructed=reconstructed,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )


class AutoencoderForReconstruction(PreTrainedModel):
    """
    Autoencoder Model with a reconstruction head on top for reconstruction tasks.

    This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
""" config_class = AutoencoderConfig base_model_prefix = "autoencoder" def __init__(self, config: AutoencoderConfig): super().__init__(config) self.config = config # Initialize the base autoencoder model self.autoencoder = AutoencoderModel(config) # Initialize weights self.post_init() def get_input_embeddings(self): """Get input embeddings.""" return self.autoencoder.get_input_embeddings() def set_input_embeddings(self, value): """Set input embeddings.""" self.autoencoder.set_input_embeddings(value) def _compute_reconstruction_loss( self, reconstructed: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: """Compute reconstruction loss based on the configured loss type.""" if self.config.reconstruction_loss == "mse": return F.mse_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "bce": return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "l1": return F.l1_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "huber": return F.huber_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "smooth_l1": return F.smooth_l1_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "kl_div": return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean") elif self.config.reconstruction_loss == "cosine": return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean() elif self.config.reconstruction_loss == "focal": return self._focal_loss(reconstructed, target) elif self.config.reconstruction_loss == "dice": return self._dice_loss(reconstructed, target) elif self.config.reconstruction_loss == "tversky": return self._tversky_loss(reconstructed, target) elif self.config.reconstruction_loss == "ssim": return self._ssim_loss(reconstructed, target) elif self.config.reconstruction_loss == "perceptual": return self._perceptual_loss(reconstructed, target) else: raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}") def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor: """Compute focal loss for handling class imbalance.""" ce_loss = F.mse_loss(pred, target, reduction="none") pt = torch.exp(-ce_loss) focal_loss = alpha * (1 - pt) ** gamma * ce_loss return focal_loss.mean() def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor: """Compute Dice loss for segmentation-like tasks.""" pred_flat = pred.view(-1) target_flat = target.view(-1) intersection = (pred_flat * target_flat).sum() dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth) return 1 - dice def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor: """Compute Tversky loss, a generalization of Dice loss.""" pred_flat = pred.view(-1) target_flat = target.view(-1) true_pos = (pred_flat * target_flat).sum() false_neg = (target_flat * (1 - pred_flat)).sum() false_pos = ((1 - target_flat) * pred_flat).sum() tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth) return 1 - tversky def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute SSIM-based loss (simplified version).""" # Simplified SSIM for 1D data mu1 = pred.mean(dim=-1, keepdim=True) mu2 = target.mean(dim=-1, keepdim=True) sigma1_sq = 
        sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
        sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
        sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)

        c1, c2 = 0.01, 0.03
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / (
            (mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2)
        )
        return 1 - ssim.mean()

    def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute perceptual loss (simplified version using feature differences)."""
        # For simplicity, use L2 loss on normalized features
        pred_norm = F.normalize(pred, p=2, dim=-1)
        target_norm = F.normalize(target, p=2, dim=-1)
        return F.mse_loss(pred_norm, target_norm)

    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
        """
        Forward pass with reconstruction loss calculation.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # If no labels provided, use input as target (standard autoencoder)
        if labels is None:
            labels = input_values

        # Forward pass through autoencoder
        outputs = self.autoencoder(
            input_values=input_values,
            sequence_lengths=sequence_lengths,
            target_length=target_length,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        reconstructed = outputs.reconstructed
        latent = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        # Compute reconstruction loss
        recon_loss = self._compute_reconstruction_loss(reconstructed, labels)

        # Add regularization losses based on autoencoder type
        total_loss = recon_loss

        # Add preprocessing loss if available
        if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None:
            total_loss = total_loss + outputs.preprocessing_loss

        if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None:
            # KL divergence loss for variational autoencoders:
            # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            kl_loss = -0.5 * torch.sum(
                1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp()
            )
            # Normalize by batch size and latent dim
            kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1))
            total_loss = total_loss + self.config.beta * kl_loss
        elif self.config.is_sparse:
            # Sparsity loss for sparse autoencoders
            latent = outputs.last_hidden_state
            sparsity_loss = torch.mean(torch.abs(latent))  # L1 sparsity
            total_loss = total_loss + 0.1 * sparsity_loss  # Sparsity weight
        elif self.config.is_contractive:
            # Contractive loss - penalize large gradients of the hidden representation w.r.t. the input.
            # The penalty is only applied when a gradient has already been populated on the latent.
            latent = outputs.last_hidden_state
            latent.retain_grad()
            if latent.grad is not None:
                contractive_loss = torch.sum(latent.grad ** 2)
                total_loss = total_loss + 0.1 * contractive_loss

        loss = total_loss

        if not return_dict:
            output = (reconstructed, latent)
            if hidden_states is not None:
                output = output + (hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return AutoencoderForReconstructionOutput(
            loss=loss,
            reconstructed=reconstructed,
            last_hidden_state=latent,
            hidden_states=hidden_states,
            preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None,
        )
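

# ---------------------------------------------------------------------------
# Minimal usage sketch (smoke test), not part of the library API. It assumes
# AutoencoderConfig accepts the keyword arguments that are read as attributes
# elsewhere in this file (input_dim, hidden_dims, latent_dim,
# reconstruction_loss); see configuration_autoencoder.py for the authoritative
# constructor signature and defaults before relying on it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical configuration: a small dense autoencoder with MSE reconstruction.
    config = AutoencoderConfig(
        input_dim=20,
        hidden_dims=[64, 32],
        latent_dim=8,
        reconstruction_loss="mse",
    )

    model = AutoencoderForReconstruction(config)
    model.eval()

    # Random batch of 4 feature vectors; labels default to the inputs themselves,
    # so the loss below is a plain reconstruction loss.
    batch = torch.randn(4, config.input_dim)
    with torch.no_grad():
        out = model(input_values=batch, return_dict=True)

    print("loss:", out.loss.item())
    print("latent shape:", tuple(out.last_hidden_state.shape))
    print("reconstruction shape:", tuple(out.reconstructed.shape))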