import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.llama.modeling_llama import LlamaRMSNorm


class SimpleAdapter(nn.Module):
    """
    MOSA Section III-B:
    "consists of two linear layers with a ReLU activation in between,
    projecting the hidden dimension from 3072 to 4096 and back to 3072."
    """

    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
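
# Shape sketch (not from the MOSA code; it just illustrates the docstring's
# 3072 -> 4096 -> 3072 projection, with arbitrary batch and sequence sizes):
#   adapter = SimpleAdapter(3072, 4096, 3072)
#   y = adapter(torch.randn(2, 25, 3072))  # y.shape == (2, 25, 3072)
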

class MoEAudioProjector(nn.Module):
    """
    MOSA-style projector: Mixture of Simple Adapters.

    From paper (arXiv:2508.18998):
    - Dense mixture (softmax over ALL experts) instead of sparse Top-K
    - Simple Linear->ReLU->Linear adapters (3072->4096->3072)
    - No auxiliary losses - just cross-entropy on transcripts
    - Conv downsampling: stride 4 total (two conv layers, stride 2 each)
    """

    def __init__(self, config):
        super().__init__()

        self.encoder_dim = config.encoder_dim
        self.llm_dim = config.llm_dim
        self.num_experts = getattr(config, "num_experts", 4)
        adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096
        self.dropout_rate = getattr(config, "projector_dropout", 0.1)

        # Temporal downsampling: two stride-2 convolutions give a total stride
        # of 4 and project encoder_dim -> llm_dim.
        self.conv = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv1d(self.llm_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )

        # Router operates on the raw encoder features and produces one logit
        # per expert for every encoder frame.
        router_hidden = 512
        self.router = nn.Sequential(
            nn.Linear(self.encoder_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # Dense mixture of SimpleAdapter experts applied to the downsampled features.
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
                for _ in range(self.num_experts)
            ]
        )

        # Final normalization in the LLM's hidden space.
        self.ln_post = LlamaRMSNorm(self.llm_dim, eps=1e-6)

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stable training."""
        std = 0.02
        with torch.no_grad():
            for module in self.conv:
                if isinstance(module, nn.Conv1d):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            for module in self.router:
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, mean=0.0, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

            for expert in self.experts:
                nn.init.normal_(expert.fc1.weight, mean=0.0, std=std)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=std)
                if expert.fc1.bias is not None:
                    nn.init.zeros_(expert.fc1.bias)
                if expert.fc2.bias is not None:
                    nn.init.zeros_(expert.fc2.bias)

            # RMSNorm scale starts at identity.
            self.ln_post.weight.data.fill_(1.0)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, encoder_dim] from Whisper encoder (1280)

        Returns:
            output: [batch_size, ceil(seq_len / 4), llm_dim] (3072); the time
            axis is right-padded to a multiple of 4 before downsampling.
        """
        batch_size, seq_len, _ = x.shape

        # Right-pad the time axis so its length is divisible by the total
        # downsampling factor of 4.
        pad_amt = (4 - (seq_len % 4)) % 4
        if pad_amt > 0:
            x = F.pad(x, (0, 0, 0, pad_amt))
            seq_len = x.shape[1]

        # Conv1d expects [batch, channels, time]; downsample 4x and project to llm_dim.
        h_conv = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Frame-level router logits, averaged over each group of 4 encoder
        # frames so the routing matches the downsampled sequence length.
        router_logits = self.router(x)
        router_logits = router_logits.view(batch_size, seq_len // 4, 4, self.num_experts).mean(
            dim=2
        )

        # Dense mixture: softmax over ALL experts (no Top-K).
        routing_weights = F.softmax(router_logits, dim=-1)

        # Weighted sum of every expert's output.
        final_out = torch.zeros_like(h_conv)
        for i, expert in enumerate(self.experts):
            expert_out = expert(h_conv)
            expert_weight = routing_weights[:, :, i : i + 1]
            final_out.add_(expert_out * expert_weight)

        return self.ln_post(final_out)

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary loss (none for dense MoE - all experts are always used)."""
        # Return a zero on the module's device so callers can add it to other losses.
        return torch.tensor(0.0, device=self.ln_post.weight.device)
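
# Minimal smoke test (a sketch, not part of the MOSA reference code). The config
# is a SimpleNamespace stand-in exposing only the attributes read in __init__
# above; dimensions follow the docstrings (Whisper encoder_dim=1280, LLM
# llm_dim=3072). 98 input frames are right-padded to 100, then downsampled 4x to 25.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        encoder_dim=1280,
        llm_dim=3072,
        num_experts=4,
        projector_hidden_dim=4096,
        projector_dropout=0.1,
    )
    projector = MoEAudioProjector(cfg)
    feats = torch.randn(2, 98, 1280)  # [batch, frames, encoder_dim]
    with torch.no_grad():
        out = projector(feats)
    print(out.shape)  # torch.Size([2, 25, 3072])
    print(projector.get_aux_loss())  # tensor(0.)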