Stockai / audio_encoder.py
rmanzo28's picture
Upload 4 files
2134603 verified
raw
history blame
580 Bytes
import torch
import torch.nn as nn
class AudioEncoder(nn.Module):
"""
Simple audio encoder using 1D CNN.
"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.cnn = nn.Sequential(
nn.Conv1d(input_dim, hidden_dim, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv1d(hidden_dim, output_dim, kernel_size=3, padding=1),
nn.AdaptiveAvgPool1d(1)
)
def forward(self, x):
# x: (batch, channels, seq_len)
out = self.cnn(x)
return out.squeeze(-1)