import torch
import torch.nn as nn


class AudioEncoder(nn.Module):
    """Simple audio encoder using a 1D CNN.

    Two 1D convolutions (with a ReLU between them) are applied along the
    sequence axis, then the sequence axis is averaged away so every input
    yields one fixed-size vector regardless of its length.

    Args:
        input_dim: Number of input channels of the first convolution.
        hidden_dim: Number of channels produced by the first convolution.
        output_dim: Number of channels produced by the second convolution,
            i.e. the size of the returned embedding.
        kernel_size: Width of both convolution kernels. Defaults to 3, which
            together with the derived padding of ``kernel_size // 2``
            reproduces the original ``kernel_size=3, padding=1`` layers.
            Even widths lengthen the sequence by one step per convolution;
            this is harmless because the sequence axis is pooled to length 1.

    Raises:
        ValueError: If ``kernel_size`` is smaller than 1.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        *,
        kernel_size: int = 3,
    ) -> None:
        super().__init__()
        if kernel_size < 1:
            raise ValueError(
                f"kernel_size must be a positive integer, got {kernel_size}"
            )
        # Symmetric padding that keeps the sequence length for odd kernels.
        padding = kernel_size // 2
        # Layer order (and therefore the state_dict keys cnn.0 / cnn.2) is
        # kept as-is so existing checkpoints continue to load.
        self.cnn = nn.Sequential(
            nn.Conv1d(input_dim, hidden_dim, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, output_dim, kernel_size=kernel_size, padding=padding),
            # Collapse the sequence axis to a single averaged step.
            nn.AdaptiveAvgPool1d(1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences into fixed-size vectors.

        Args:
            x: Tensor of shape ``(batch, channels, seq_len)``, where
                ``channels`` equals ``input_dim``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        out = self.cnn(x)
        # The pooled axis always has length 1; drop only that axis so a
        # batch of size 1 keeps its batch dimension.
        return out.squeeze(-1)