"""
PyTorch Autoencoder model for Hugging Face Transformers.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union, Dict, Any, List
from dataclasses import dataclass
import random
import re

try:
    from transformers.modeling_utils import PreTrainedModel
except Exception:
    from transformers import PreTrainedModel

from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import ModelOutput

try:
    from .configuration_autoencoder import AutoencoderConfig
except Exception:
    from configuration_autoencoder import AutoencoderConfig

try:
    from .blocks import (
        BlockFactory,
        BlockSequence,
        LinearBlockConfig,
        AttentionBlockConfig,
        RecurrentBlockConfig,
        ConvolutionalBlockConfig,
        VariationalBlockConfig,
        VariationalBlock,
    )
except Exception:
    from blocks import (
        BlockFactory,
        BlockSequence,
        LinearBlockConfig,
        AttentionBlockConfig,
        RecurrentBlockConfig,
        ConvolutionalBlockConfig,
        VariationalBlockConfig,
        VariationalBlock,
    )

try:
    from .utils import _get_activation
except Exception:
    from utils import _get_activation

try:
    from .preprocessing import PreprocessingBlock
except Exception:
    from preprocessing import PreprocessingBlock


@dataclass
class AutoencoderOutput(ModelOutput):
    """
    Output type of AutoencoderModel.

    Args:
        last_hidden_state (torch.FloatTensor): The latent representation of the input.
        reconstructed (torch.FloatTensor, optional): The reconstructed input.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        attentions (tuple(torch.FloatTensor), optional): Not used in the basic autoencoder.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    last_hidden_state: torch.FloatTensor = None
    reconstructed: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None


@dataclass
class AutoencoderForReconstructionOutput(ModelOutput):
    """
    Output type of AutoencoderForReconstruction.

    Args:
        loss (torch.FloatTensor, optional): The reconstruction loss.
        reconstructed (torch.FloatTensor): The reconstructed input.
        last_hidden_state (torch.FloatTensor): The latent representation.
        hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
        preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstructed: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    preprocessing_loss: Optional[torch.FloatTensor] = None


class AutoencoderEncoder(nn.Module):
    """Encoder part of the autoencoder."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Build the MLP encoder stack: Linear -> (BatchNorm) -> activation -> (Dropout).
        layers = []
        input_dim = config.input_dim

        for hidden_dim in config.hidden_dims:
            layers.append(nn.Linear(input_dim, hidden_dim))

            if config.use_batch_norm:
                layers.append(nn.BatchNorm1d(hidden_dim))

            # Use the shared activation factory imported from utils; this class has
            # no `_get_activation` method of its own.
            layers.append(_get_activation(config.activation))

            if config.dropout_rate > 0:
                layers.append(nn.Dropout(config.dropout_rate))

            input_dim = hidden_dim

        self.encoder = nn.Sequential(*layers)

        # Latent projection: separate mu/logvar heads for the variational case.
        if config.is_variational:
            self.fc_mu = nn.Linear(input_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(input_dim, config.latent_dim)
        else:
            self.fc_out = nn.Linear(input_dim, config.latent_dim)

    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Forward pass through encoder."""
        # Denoising autoencoder: corrupt the input with Gaussian noise during training.
        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        encoded = self.encoder(x)

        if self.config.is_variational:
            mu = self.fc_mu(encoded)
            logvar = self.fc_logvar(encoded)

            # Reparameterization trick: sample z = mu + eps * std during training,
            # use the mean deterministically at inference time.
            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            latent = self.fc_out(encoded)

            # Sparse autoencoder: keep activations non-negative so the sparsity
            # penalty applied downstream is meaningful.
            if self.config.is_sparse and self.training:
                latent = F.relu(latent)

            return latent


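# Illustrative shapes for AutoencoderEncoder (a sketch, not executed at import; assumes
# an AutoencoderConfig with input_dim=64, hidden_dims=[32, 16] and latent_dim=8):
#
#     encoder = AutoencoderEncoder(config)
#     z = encoder(torch.randn(4, 64))              # plain AE config: z has shape (4, 8)
#     z, mu, logvar = encoder(torch.randn(4, 64))  # variational AE config: three (4, 8) tensors

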
class AutoencoderDecoder(nn.Module):
    """Decoder part of the autoencoder."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        # Mirror the encoder: hidden decoder layers, then a final projection back to input_dim.
        layers = []
        input_dim = config.latent_dim
        decoder_dims = config.decoder_dims + [config.input_dim]

        for i, hidden_dim in enumerate(decoder_dims):
            layers.append(nn.Linear(input_dim, hidden_dim))

            if i < len(decoder_dims) - 1:
                if config.use_batch_norm:
                    layers.append(nn.BatchNorm1d(hidden_dim))

                layers.append(_get_activation(config.activation))

                if config.dropout_rate > 0:
                    layers.append(nn.Dropout(config.dropout_rate))
            else:
                # Final layer: squash outputs to [0, 1] when training with a BCE objective.
                if config.reconstruction_loss == "bce":
                    layers.append(nn.Sigmoid())

            input_dim = hidden_dim

        self.decoder = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through decoder."""
        return self.decoder(x)


class RecurrentEncoder(nn.Module):
    """Recurrent encoder for sequence data."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        self.rnn = rnn_class(
            input_size=config.input_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=config.bidirectional,
        )

        # A bidirectional RNN doubles the hidden size, so project back to latent_dim.
        if config.bidirectional:
            self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
        else:
            self.projection = None

        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Forward pass through recurrent encoder.

        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            lengths: Sequence lengths for packed sequences (optional)

        Returns:
            Encoded representation, or a (z, mu, logvar) tuple for the variational case.
        """
        batch_size, seq_len, _ = x.shape

        if self.config.is_denoising and self.training:
            noise = torch.randn_like(x) * self.config.noise_factor
            x = x + noise

        if lengths is not None:
            # pack_padded_sequence requires the lengths on the CPU.
            if torch.is_tensor(lengths):
                lengths = lengths.cpu()
            x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        if self.config.rnn_type == "lstm":
            output, (hidden, cell) = self.rnn(x)
        else:
            output, hidden = self.rnn(x)
            cell = None

        if lengths is not None:
            output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)

        if self.config.bidirectional:
            # hidden: (num_layers * 2, batch, latent_dim) -> concatenate both directions
            # of the last layer into (batch, 2 * latent_dim).
            hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
            hidden = hidden[-1]
            hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)

            if self.projection:
                hidden = self.projection(hidden)
        else:
            # Use the final hidden state of the last layer.
            hidden = hidden[-1]

        if self.batch_norm:
            hidden = self.batch_norm(hidden)

        if self.dropout and self.training:
            hidden = self.dropout(hidden)

        if self.config.is_variational:
            # Split the hidden state into mean and log-variance halves.
            mu = hidden[:, :self.config.latent_dim // 2]
            logvar = hidden[:, self.config.latent_dim // 2:]

            if self.training:
                std = torch.exp(0.5 * logvar)
                eps = torch.randn_like(std)
                z = mu + eps * std
            else:
                z = mu

            return z, mu, logvar
        else:
            return hidden


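# Illustrative shapes for RecurrentEncoder (a sketch, not executed at import; assumes a
# config with input_dim=16, latent_dim=8, num_layers=1, rnn_type="gru", bidirectional=False):
#
#     encoder = RecurrentEncoder(config)
#     x = torch.randn(4, 20, 16)                   # (batch, seq_len, input_dim)
#     lengths = torch.tensor([20, 18, 15, 9])      # optional per-sequence lengths
#     hidden = encoder(x, lengths=lengths)         # (4, 8) summary of each sequence

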
class RecurrentDecoder(nn.Module):
    """Recurrent decoder for sequence data."""

    def __init__(self, config: AutoencoderConfig):
        super().__init__()
        self.config = config

        if config.rnn_type == "lstm":
            rnn_class = nn.LSTM
        elif config.rnn_type == "gru":
            rnn_class = nn.GRU
        elif config.rnn_type == "rnn":
            rnn_class = nn.RNN
        else:
            raise ValueError(f"Unknown RNN type: {config.rnn_type}")

        self.rnn = rnn_class(
            input_size=config.latent_dim,
            hidden_size=config.latent_dim,
            num_layers=config.num_layers,
            batch_first=True,
            dropout=config.dropout_rate if config.num_layers > 1 else 0,
            bidirectional=False,
        )

        self.output_projection = nn.Linear(config.latent_dim, config.input_dim)

        if config.use_batch_norm:
            self.batch_norm = nn.BatchNorm1d(config.latent_dim)
        else:
            self.batch_norm = None

        if config.dropout_rate > 0:
            self.dropout = nn.Dropout(config.dropout_rate)
        else:
            self.dropout = None

    def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass through recurrent decoder.

        Args:
            z: Latent representation of shape (batch_size, latent_dim)
            target_length: Length of sequence to generate
            target_sequence: Target sequence for teacher forcing (optional)

        Returns:
            Decoded sequence of shape (batch_size, seq_len, input_dim)
        """
        batch_size = z.size(0)
        device = z.device

        # Use the latent vector as the initial hidden state of every RNN layer.
        if self.config.rnn_type == "lstm":
            h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
            c_0 = torch.zeros_like(h_0)
            hidden = (h_0, c_0)
        else:
            hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)

        outputs = []

        # Default step input is a zero vector unless teacher forcing provides one.
        current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        for t in range(target_length):
            use_teacher_forcing = (
                target_sequence is not None
                and self.training
                and random.random() < self.config.teacher_forcing_ratio
            )

            if use_teacher_forcing and t > 0:
                current_input = target_sequence[:, t - 1:t, :]

                # Fall back to a zero input if the target features do not match latent_dim.
                if current_input.size(-1) != self.config.latent_dim:
                    current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

            # The call is identical for LSTM and non-LSTM cells; `hidden` already has the
            # right structure for each.
            output, hidden = self.rnn(current_input, hidden)

            output_flat = output.squeeze(1)

            if self.batch_norm:
                output_flat = self.batch_norm(output_flat)

            if self.dropout and self.training:
                output_flat = self.dropout(output_flat)

            step_output = self.output_projection(output_flat)
            outputs.append(step_output.unsqueeze(1))

            if not use_teacher_forcing:
                current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)

        return torch.cat(outputs, dim=1)


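# Illustrative use of RecurrentDecoder (a sketch, not executed at import; assumes the same
# hypothetical config as above with latent_dim=8 and input_dim=16):
#
#     decoder = RecurrentDecoder(config)
#     z = torch.randn(4, 8)                                    # latent codes
#     recon = decoder(z, target_length=20)                     # (4, 20, 16)
#     recon = decoder(z, target_length=20, target_sequence=x)  # teacher forcing during training

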
class AutoencoderModel(PreTrainedModel):
    """
    The bare Autoencoder Model outputting raw hidden states without any specific head on top.

    This model inherits from PreTrainedModel. Check the superclass documentation for the generic
    methods the library implements for all its models (such as downloading or saving, resizing the
    input embeddings, pruning heads, etc.).

    This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matters related to general usage and behavior.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"
    supports_gradient_checkpointing = False

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        # Optional learnable preprocessing applied before encoding.
        if config.has_preprocessing:
            self.pre_block = PreprocessingBlock(config, inverse=False)
        else:
            self.pre_block = None

        norm = "batch" if config.use_batch_norm else "none"

        def default_linear_sequence(in_dim: int, dims: List[int], activation: str, normalization: str, dropout: float) -> List[LinearBlockConfig]:
            cfgs: List[LinearBlockConfig] = []
            prev = in_dim
            for h in dims:
                cfgs.append(
                    LinearBlockConfig(
                        input_dim=prev,
                        output_dim=h,
                        activation=activation,
                        normalization=normalization,
                        dropout_rate=dropout,
                        use_residual=False,
                    )
                )
                prev = h
            return cfgs

        # Encoder blocks: either explicitly configured or a default linear stack.
        if getattr(config, "encoder_blocks", None):
            enc_cfgs = config.encoder_blocks

            last_out = None
            for b in enc_cfgs:
                if isinstance(b, dict):
                    last_out = b.get("output_dim", last_out)
                else:
                    last_out = getattr(b, "output_dim", last_out)
            enc_out_dim = last_out or (config.hidden_dims[-1] if config.hidden_dims else config.input_dim)
        else:
            enc_cfgs = default_linear_sequence(config.input_dim, config.hidden_dims, config.activation, norm, config.dropout_rate)
            enc_out_dim = config.hidden_dims[-1] if config.hidden_dims else config.input_dim
        base_encoder_seq: BlockSequence = BlockFactory.build_sequence(enc_cfgs) if len(enc_cfgs) > 0 else BlockSequence([])

        self.encoder_seq = base_encoder_seq

        # Latent projection: mu/logvar heads for the variational case, a single linear otherwise.
        if config.is_variational:
            self.fc_mu = nn.Linear(enc_out_dim, config.latent_dim)
            self.fc_logvar = nn.Linear(enc_out_dim, config.latent_dim)
            self.to_latent = None
        else:
            self.fc_mu = None
            self.fc_logvar = None
            self.to_latent = nn.Linear(enc_out_dim, config.latent_dim)

        # Decoder blocks: either explicitly configured or a default linear stack back to input_dim.
        if getattr(config, "decoder_blocks", None):
            dec_cfgs = config.decoder_blocks
        else:
            dec_dims = config.decoder_dims + [config.input_dim]
            dec_cfgs = default_linear_sequence(config.latent_dim, dec_dims, config.activation, norm, config.dropout_rate)

        # The final decoder block reconstructs the input, so it gets no activation,
        # normalization, or dropout.
        if len(dec_cfgs) > 0:
            last = dec_cfgs[-1]
            last.activation = "identity"
            last.normalization = "none"
            last.dropout_rate = 0.0
        self.decoder_seq: BlockSequence = BlockFactory.build_sequence(dec_cfgs) if len(dec_cfgs) > 0 else BlockSequence([])

        if config.tie_weights:
            self._tie_weights()

        self.post_init()

    def _tie_weights(self):
        """Tie encoder and decoder weights (transpose relationship)."""
        pass

    def get_input_embeddings(self):
        """Get input embeddings (not applicable for basic autoencoder)."""
        return None

    def set_input_embeddings(self, value):
        """Set input embeddings (not applicable for basic autoencoder)."""
        pass

    def forward(
        self,
        input_values: torch.Tensor,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
        """
        Forward pass through the autoencoder.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderOutput or tuple: The model outputs.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Optional learnable preprocessing. Keep preprocessing_loss defined even when
        # no preprocessing block is configured.
        preprocessing_loss = None
        if self.pre_block is not None:
            input_values = self.pre_block(input_values)
            preprocessing_loss = torch.tensor(0.0, device=input_values.device)

        enc_out = self.encoder_seq(input_values)

        if self.config.is_variational:
            # Lazily build the variational block the first time it is needed.
            self._variational = getattr(self, '_variational', None)
            if self._variational is None:
                self._variational = VariationalBlock(VariationalBlockConfig(input_dim=enc_out.shape[-1], latent_dim=self.config.latent_dim)).to(enc_out.device)
            latent = self._variational(enc_out, training=self.training)
            self._mu = self._variational._mu
            self._logvar = self._variational._logvar
        else:
            latent = self.to_latent(enc_out) if self.to_latent is not None else enc_out
            self._mu, self._logvar = None, None

        reconstructed = self.decoder_seq(latent)

        hidden_states = None
        if output_hidden_states:
            if self.config.is_variational:
                hidden_states = (latent, getattr(self, '_mu', None), getattr(self, '_logvar', None))
            else:
                hidden_states = (latent,)

        if not return_dict:
            return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)

        return AutoencoderOutput(
            last_hidden_state=latent,
            reconstructed=reconstructed,
            hidden_states=hidden_states,
            preprocessing_loss=preprocessing_loss,
        )


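# Illustrative use of AutoencoderModel (a sketch, not executed at import; assumes an
# AutoencoderConfig with input_dim=64, hidden_dims=[32, 16], decoder_dims=[16, 32],
# latent_dim=8):
#
#     model = AutoencoderModel(config)
#     outputs = model(torch.randn(4, 64), return_dict=True)
#     outputs.last_hidden_state.shape   # (4, 8)  latent codes
#     outputs.reconstructed.shape       # (4, 64) reconstruction of the input

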
class AutoencoderForReconstruction(PreTrainedModel):
    """
    Autoencoder Model with a reconstruction head on top for reconstruction tasks.

    This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
    """

    config_class = AutoencoderConfig
    base_model_prefix = "autoencoder"

    def __init__(self, config: AutoencoderConfig):
        super().__init__(config)
        self.config = config

        self.autoencoder = AutoencoderModel(config)

        self.post_init()

    def get_input_embeddings(self):
        """Get input embeddings."""
        return self.autoencoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Set input embeddings."""
        self.autoencoder.set_input_embeddings(value)

    def _compute_reconstruction_loss(
        self,
        reconstructed: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """Compute reconstruction loss based on the configured loss type."""
        if self.config.reconstruction_loss == "mse":
            return F.mse_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "bce":
            return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "l1":
            return F.l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "huber":
            return F.huber_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "smooth_l1":
            return F.smooth_l1_loss(reconstructed, target, reduction="mean")
        elif self.config.reconstruction_loss == "kl_div":
            return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean")
        elif self.config.reconstruction_loss == "cosine":
            return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()
        elif self.config.reconstruction_loss == "focal":
            return self._focal_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "dice":
            return self._dice_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "tversky":
            return self._tversky_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "ssim":
            return self._ssim_loss(reconstructed, target)
        elif self.config.reconstruction_loss == "perceptual":
            return self._perceptual_loss(reconstructed, target)
        else:
            raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")

    def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
        """Compute a focal-style loss that down-weights easy (low-error) elements."""
        ce_loss = F.mse_loss(pred, target, reduction="none")
        pt = torch.exp(-ce_loss)
        focal_loss = alpha * (1 - pt) ** gamma * ce_loss
        return focal_loss.mean()

    def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Dice loss for segmentation-like tasks."""
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        intersection = (pred_flat * target_flat).sum()
        dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
        return 1 - dice

    def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
        """Compute Tversky loss, a generalization of Dice loss."""
        pred_flat = pred.view(-1)
        target_flat = target.view(-1)
        true_pos = (pred_flat * target_flat).sum()
        false_neg = (target_flat * (1 - pred_flat)).sum()
        false_pos = ((1 - target_flat) * pred_flat).sum()
        tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
        return 1 - tversky

    def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute SSIM-based loss (simplified version over the last dimension)."""
        mu1 = pred.mean(dim=-1, keepdim=True)
        mu2 = target.mean(dim=-1, keepdim=True)
        sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
        sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
        sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)

        c1, c2 = 0.01, 0.03
        ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
        return 1 - ssim.mean()

    def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute perceptual loss (simplified version using normalized feature differences)."""
        pred_norm = F.normalize(pred, p=2, dim=-1)
        target_norm = F.normalize(target, p=2, dim=-1)
        return F.mse_loss(pred_norm, target_norm)

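    # Quick sanity sketch for the configurable reconstruction loss (illustrative only,
    # never executed; assumes `model.config.reconstruction_loss` is set to one of the
    # names handled above):
    #
    #     model.config.reconstruction_loss = "mse"
    #     model._compute_reconstruction_loss(torch.zeros(2, 4), torch.ones(2, 4))  # tensor(1.)
    #     model.config.reconstruction_loss = "l1"
    #     model._compute_reconstruction_loss(torch.zeros(2, 4), torch.ones(2, 4))  # tensor(1.)
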
    def forward(
        self,
        input_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        sequence_lengths: Optional[torch.Tensor] = None,
        target_length: Optional[int] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
        """
        Forward pass with reconstruction loss calculation.

        Args:
            input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
                - Standard: (batch_size, input_dim)
                - Recurrent: (batch_size, seq_len, input_dim)
            labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
            sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
            target_length (int, optional): Target sequence length for recurrent decoder.
            output_hidden_states (bool, optional): Whether to return hidden states.
            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.

        Returns:
            AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Plain autoencoding: reconstruct the input itself when no labels are given.
        if labels is None:
            labels = input_values

        outputs = self.autoencoder(
            input_values=input_values,
            sequence_lengths=sequence_lengths,
            target_length=target_length,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        reconstructed = outputs.reconstructed
        latent = outputs.last_hidden_state
        hidden_states = outputs.hidden_states

        recon_loss = self._compute_reconstruction_loss(reconstructed, labels)

        total_loss = recon_loss

        if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None:
            total_loss += outputs.preprocessing_loss

        if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None:
            # KL divergence between the approximate posterior N(mu, sigma^2) and the standard
            # normal prior, averaged over batch and latent dimensions and weighted by beta
            # (beta-VAE).
            kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp())
            kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1))
            total_loss = recon_loss + self.config.beta * kl_loss

        elif self.config.is_sparse:
            # L1 penalty on the latent activations encourages sparse codes.
            latent = outputs.last_hidden_state
            sparsity_loss = torch.mean(torch.abs(latent))
            total_loss = recon_loss + 0.1 * sparsity_loss

        elif self.config.is_contractive:
            # Best-effort contractive penalty: latent.grad is typically None during a plain
            # forward pass, so this term only contributes when a gradient from a previous
            # backward pass is still attached. Guarded so it is a no-op under torch.no_grad().
            latent = outputs.last_hidden_state
            if latent.requires_grad:
                latent.retain_grad()
                if latent.grad is not None:
                    contractive_loss = torch.sum(latent.grad ** 2)
                    total_loss = recon_loss + 0.1 * contractive_loss

        loss = total_loss

        if not return_dict:
            output = (reconstructed, latent)
            if hidden_states is not None:
                output = output + (hidden_states,)
            return ((loss,) + output) if loss is not None else output

        return AutoencoderForReconstructionOutput(
            loss=loss,
            reconstructed=reconstructed,
            last_hidden_state=latent,
            hidden_states=hidden_states,
            preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None,
        )


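# End-to-end training sketch for AutoencoderForReconstruction (illustrative only, not
# executed at import; config fields, sizes, and `dataloader` are assumptions for the example):
#
#     config = AutoencoderConfig(input_dim=64, hidden_dims=[32, 16], decoder_dims=[16, 32], latent_dim=8)
#     model = AutoencoderForReconstruction(config)
#     optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#
#     for batch in dataloader:                      # batch: (batch_size, 64)
#         outputs = model(input_values=batch, return_dict=True)
#         outputs.loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
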