"""Shared MoE Audio Projector. |
|
|
|
|
|
A simplified MoE projector combining the best ideas: |
|
|
- Shared expert: Always-on baseline processing (from GLM4) |
|
|
- Zero-initialized router: Learns specialization naturally (from Qwen3) |
|
|
- Simple top-k softmax: No grouping complexity (from Mixtral) |
|
|
- Renormalized weights: Top-k weights sum to 1 |
|
|
|
|
|
Architecture: |
|
|
Output = SharedExpert(x) + TopKRoutedExperts(x) |
|
|
|
|
|
The shared expert ensures every audio token gets consistent baseline |
|
|
processing, while routed experts can specialize for different patterns |
|
|
(e.g., vowels vs consonants, silence vs speech). |
|
|
""" |

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Single SwiGLU expert MLP."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: down(SiLU(gate(x)) * up(x))
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class SharedExpert(SwiGLUExpert):
    """Shared expert MLP that processes all tokens.

    Structurally identical to a routed SwiGLUExpert; kept as a separate class
    so call sites stay self-documenting.
    """


class RoutedExperts(nn.Module):
    """Sparse routed experts using token dispatch.

    For each expert, gathers its assigned tokens, processes them, then scatters
    the results back. Memory-efficient: the activation cost is
    O(num_tokens * top_k * hidden_dim) rather than the
    O(num_tokens * num_experts * hidden_dim) of a dense mixture.
    """

    def __init__(
        self, num_experts: int, top_k: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.output_dim = output_dim

        self.experts = nn.ModuleList([
            SwiGLUExpert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_indices: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """Token-dispatch forward pass (memory-efficient).

        Args:
            hidden_states: [num_tokens, input_dim]
            top_k_indices: [num_tokens, top_k]
            top_k_weights: [num_tokens, top_k]

        Returns:
            output: [num_tokens, output_dim]
        """
        num_tokens = hidden_states.shape[0]
        device = hidden_states.device
        dtype = hidden_states.dtype

        output = torch.zeros(num_tokens, self.output_dim, device=device, dtype=dtype)

        for expert_idx, expert in enumerate(self.experts):
            # [num_tokens, top_k] mask of slots routed to this expert.
            expert_mask = top_k_indices == expert_idx
            if not expert_mask.any():
                continue

            # Gather the assigned tokens (and the top-k slot each came from).
            token_indices, slot_indices = torch.where(expert_mask)
            expert_input = hidden_states[token_indices]
            expert_output = expert(expert_input)

            # Scale each token's output by its routing weight and scatter-add
            # it back into place (a token can be routed to several experts).
            weights = top_k_weights[token_indices, slot_indices]
            weighted_output = expert_output * weights.unsqueeze(-1)
            output.index_add_(0, token_indices, weighted_output)

        return output
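

# A dense reference implementation is handy when unit-testing the dispatch loop
# above (an illustrative sketch, not used anywhere in this module): it runs
# every expert on every token and masks afterwards, so it must agree with
# RoutedExperts.forward up to floating-point accumulation order.
def _dense_moe_reference(
    experts: nn.ModuleList,
    hidden_states: torch.Tensor,
    top_k_indices: torch.Tensor,
    top_k_weights: torch.Tensor,
) -> torch.Tensor:
    num_experts = len(experts)
    # [num_tokens, num_experts, output_dim]: every expert applied to every token.
    all_out = torch.stack([expert(hidden_states) for expert in experts], dim=1)
    # Dense combine weights: the routing weight where selected, zero elsewhere.
    combine = torch.zeros(
        hidden_states.shape[0], num_experts,
        device=hidden_states.device, dtype=all_out.dtype,
    )
    combine.scatter_(1, top_k_indices, top_k_weights.to(all_out.dtype))
    return (all_out * combine.unsqueeze(-1)).sum(dim=1)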


class SharedMoEBlock(nn.Module):
    """MoE block with shared expert + sparse routed experts."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 4,
        top_k: int = 2,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Zero-initialized router: every expert starts with uniform probability,
        # so specialization emerges from training rather than random init.
        self.router = nn.Linear(input_dim, num_experts, bias=False)
        nn.init.zeros_(self.router.weight)

        self.shared_expert = SharedExpert(input_dim, hidden_dim, output_dim)
        self.routed_experts = RoutedExperts(
            num_experts, self.top_k, input_dim, hidden_dim, output_dim
        )

        # Stashed for the auxiliary losses (see SharedMoEAudioProjector.get_aux_loss).
        self.last_router_logits = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.shape

        # Always-on baseline path.
        shared_out = self.shared_expert(hidden_states)

        # Route per token.
        flat_hidden = hidden_states.view(-1, dim)
        router_logits = self.router(flat_hidden)
        self.last_router_logits = router_logits

        # Softmax in float32 for numerical stability, then keep the top-k
        # experts and renormalize their weights to sum to 1.
        router_probs = F.softmax(router_logits.float(), dim=-1)
        top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        top_k_weights = top_k_weights.to(hidden_states.dtype)

        routed_out = self.routed_experts(flat_hidden, top_k_indices, top_k_weights)
        routed_out = routed_out.view(batch_size, seq_len, -1)

        return shared_out + routed_out


def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
    """Auxiliary loss to encourage balanced expert usage."""
    if router_logits is None:
        return torch.tensor(0.0)

    probs = F.softmax(router_logits.float(), dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)

    # Fraction of (token, slot) assignments each expert received.
    expert_mask = F.one_hot(selected, num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))

    # Mean router probability per expert.
    prob_per_expert = probs.mean(dim=0)

    # Equals 1.0 when routing is perfectly uniform; grows as usage skews.
    return (tokens_per_expert * prob_per_expert).sum() * num_experts
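
# Worked example: with all-zero logits over 4 experts, every token's
# probabilities are 0.25, so (tokens_per_expert * prob_per_expert).sum() is
# 0.25 no matter which 2 experts win the ties, giving 0.25 * 4 = 1.0, the
# balanced optimum:
#
#   >>> load_balancing_loss(torch.zeros(8, 4), num_experts=4, top_k=2)
#   tensor(1.)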


def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
    """Z-loss to prevent router logits from growing too large.

    Following the router z-loss of ST-MoE: penalizes large logits to keep the
    softmax in its "soft" regime where gradients flow properly.
    """
    if router_logits is None:
        return torch.tensor(0.0)

    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
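
# Worked example: for all-zero logits over 4 experts, logsumexp = log(4) ~= 1.386
# per token, so the loss is log(4)**2 ~= 1.92; confident (peaked) routers with
# large logits pay proportionally more:
#
#   >>> z_loss(torch.zeros(8, 4))
#   tensor(1.9218)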


class SharedMoEAudioProjector(nn.Module):
    """Shared MoE Audio Projector.

    Combines a shared expert (always-on) with sparse routed experts.
    Uses a zero-initialized router for natural specialization learning.

    Config options:

    - num_experts: Number of routed experts (default: 4)
    - num_experts_per_tok: Top-k routing (default: 2)
    - router_aux_loss_coef: Load balancing loss weight (default: 0.01)
    - router_z_loss_coef: Z-loss weight to prevent large logits (default: 0.001)

    See the __main__ smoke test at the bottom of this file for a usage example.
    """

    def __init__(self, config):
        super().__init__()

        # Temporal pooling: stack k consecutive encoder frames into one token.
        self.k = getattr(config, "projector_pool_stride", 4)

        self.encoder_dim = config.encoder_dim
        in_dim = self.encoder_dim * self.k
        out_dim = config.llm_dim

        # Fall back to the stacked input width if no hidden size is configured.
        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim

        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)
        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.01)
        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)

        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)

        self._init_weights()

    def _init_weights(self):
        with torch.no_grad():
            in_dim = self.encoder_dim * self.k
            std = 1.0 / (in_dim ** 0.5)

            # Halve the std of the output projection to temper the residual sum.
            down_proj_std = std / 2.0

            nn.init.normal_(self.moe.shared_expert.gate_proj.weight, std=std)
            nn.init.normal_(self.moe.shared_expert.up_proj.weight, std=std)
            nn.init.normal_(self.moe.shared_expert.down_proj.weight, std=down_proj_std)

            # Routed experts start with zeroed output projections, so the block
            # initially behaves exactly like the shared expert alone and the
            # routed contribution grows from zero during training.
            for expert in self.moe.routed_experts.experts:
                nn.init.normal_(expert.gate_proj.weight, std=std)
                nn.init.normal_(expert.up_proj.weight, std=std)
                nn.init.zeros_(expert.down_proj.weight)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        batch_size, seq_len, dim = x.size()

        # Match the projector's parameter dtype (e.g., under mixed precision).
        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Right-pad the time axis so seq_len divides evenly by the pool stride.
        if seq_len % self.k:
            pad = self.k - seq_len % self.k
            x = F.pad(x, (0, 0, 0, pad))
            if attention_mask is not None:
                attention_mask = F.pad(attention_mask, (0, pad), value=0)

        # Pool the mask alongside the features: a pooled token is valid if any
        # of its k source frames was valid.
        if attention_mask is not None:
            pooled_mask = F.max_pool1d(attention_mask.float().unsqueeze(1), self.k, self.k)
            self.last_attention_mask = pooled_mask.squeeze(1).bool()
        else:
            self.last_attention_mask = None

        # Stack k consecutive frames: [B, T, D] -> [B, T/k, D*k].
        x = x.reshape(batch_size, -1, dim * self.k)

        return self.moe(x)

    def get_aux_loss(self) -> torch.Tensor:
        """Get auxiliary losses (call after forward).

        Combines:

        - Load balancing loss: encourages balanced expert usage
        - Z-loss: prevents router logits from growing too large
        """
        router_logits = self.moe.last_router_logits
        if router_logits is None:
            return torch.tensor(0.0, device=self.moe.router.weight.device)

        attention_mask = getattr(self, "last_attention_mask", None)

        # Exclude padded-out tokens from the router statistics.
        if attention_mask is not None:
            flat_mask = attention_mask.view(-1)
            if flat_mask.any():
                active_logits = router_logits[flat_mask]
            else:
                # Every token is padding; nothing to balance.
                return torch.tensor(0.0, device=router_logits.device)
        else:
            active_logits = router_logits

        balance_loss = load_balancing_loss(active_logits, self.num_experts, self.top_k)
        z = z_loss(active_logits)

        return self.aux_loss_coef * balance_loss + self.z_loss_coef * z
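

if __name__ == "__main__":
    # Minimal smoke test. SimpleNamespace stands in for the real config class
    # (not defined in this file); the field names follow the getattr calls above.
    from types import SimpleNamespace

    config = SimpleNamespace(
        encoder_dim=128,
        llm_dim=256,
        projector_pool_stride=4,
        num_experts=4,
        num_experts_per_tok=2,
        router_aux_loss_coef=0.01,
        router_z_loss_coef=0.001,
    )
    projector = SharedMoEAudioProjector(config)

    x = torch.randn(2, 50, 128)  # [batch, frames, encoder_dim]
    mask = torch.ones(2, 50, dtype=torch.long)
    out = projector(x, attention_mask=mask)  # 50 frames pad to 52, pool to 13

    print(out.shape)                 # torch.Size([2, 13, 256])
    print(projector.get_aux_loss())  # small scalar: router starts uniform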