"""Shared MoE Audio Projector.

A simplified MoE projector combining the best ideas:
- Shared expert: Always-on baseline processing (from GLM4)
- Zero-initialized router: Learns specialization naturally (from Qwen3)
- Simple top-k softmax: No grouping complexity (from Mixtral)
- Renormalized weights: Top-k weights sum to 1

Architecture:
    Output = SharedExpert(x) + TopKRoutedExperts(x)

The shared expert ensures every audio token gets consistent baseline
processing, while routed experts can specialize for different patterns
(e.g., vowels vs consonants, silence vs speech).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812


class SharedExpert(nn.Module):
    """Shared expert MLP that processes all tokens."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class SwiGLUExpert(nn.Module):
    """Single SwiGLU expert MLP."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))


class RoutedExperts(nn.Module):
    """
    Sparse routed experts using token dispatch.

    For each expert, gathers assigned tokens, processes them, then scatters back.
    Memory-efficient: activation memory scales as O(num_tokens * top_k * hidden_dim)
    rather than the O(num_tokens * num_experts * hidden_dim) of a dense pass
    through every expert.
    """

    def __init__(
        self, num_experts: int, top_k: int, input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.output_dim = output_dim

        # ModuleList of expert MLPs
        self.experts = nn.ModuleList([
            SwiGLUExpert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_indices: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        """
        Token dispatch approach - memory efficient.

        Args:
            hidden_states: [num_tokens, input_dim]
            top_k_indices: [num_tokens, top_k]
            top_k_weights: [num_tokens, top_k]

        Returns:
            output: [num_tokens, output_dim]
        """
        num_tokens = hidden_states.shape[0]
        device = hidden_states.device
        dtype = hidden_states.dtype

        # Output accumulator
        output = torch.zeros(num_tokens, self.output_dim, device=device, dtype=dtype)

        # Process each expert
        for expert_idx, expert in enumerate(self.experts):
            # Find which (token, slot) pairs use this expert
            # top_k_indices: [N, K], we want all positions where value == expert_idx
            expert_mask = top_k_indices == expert_idx  # [N, K]

            if not expert_mask.any():
                continue

            # Get token indices and slot indices where this expert is selected
            token_indices, slot_indices = torch.where(expert_mask)

            # Gather the tokens for this expert
            expert_input = hidden_states[token_indices]  # [num_selected, input_dim]

            # Process through expert
            expert_output = expert(expert_input)  # [num_selected, output_dim]

            # Get weights for these tokens at these slots
            weights = top_k_weights[token_indices, slot_indices]  # [num_selected]

            # Weighted output
            weighted_output = expert_output * weights.unsqueeze(-1)

            # Scatter-add back to output
            output.index_add_(0, token_indices, weighted_output)

        return output


class SharedMoEBlock(nn.Module):
    """MoE block with shared expert + sparse routed experts."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_experts: int = 4,
        top_k: int = 2,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router: zero-initialized for natural learning
        self.router = nn.Linear(input_dim, num_experts, bias=False)
        nn.init.zeros_(self.router.weight)
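        # Zero weights give every expert the same initial probability
        # (1 / num_experts), so routing differentiates gradually during training.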

        # Shared expert (always active)
        self.shared_expert = SharedExpert(input_dim, hidden_dim, output_dim)

        # Routed experts (sparse)
        self.routed_experts = RoutedExperts(
            num_experts, self.top_k, input_dim, hidden_dim, output_dim
        )


        # For auxiliary loss
        self.last_router_logits = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = hidden_states.shape

        # Shared expert output (all tokens)
        shared_out = self.shared_expert(hidden_states)

        # Routing
        flat_hidden = hidden_states.view(-1, dim)
        router_logits = self.router(flat_hidden)
        self.last_router_logits = router_logits

        # Softmax -> top-k -> renormalize
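        # e.g. probs [0.40, 0.30, 0.20, 0.10] with top_k=2 -> weights
        # [0.40, 0.30] -> renormalized to [0.57, 0.43] (sums to 1).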
        router_probs = F.softmax(router_logits.float(), dim=-1)
        top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        top_k_weights = top_k_weights.to(hidden_states.dtype)

        # Routed expert output
        routed_out = self.routed_experts(flat_hidden, top_k_indices, top_k_weights)
        routed_out = routed_out.view(batch_size, seq_len, -1)

        # Combine: shared expert baseline + routed experts (grow in via zero-init down_proj)
        return shared_out + routed_out


def load_balancing_loss(
    router_logits: torch.Tensor | None, num_experts: int, top_k: int
) -> torch.Tensor:
    """Auxiliary loss to encourage balanced expert usage."""
    if router_logits is None:
        return torch.tensor(0.0)

    probs = F.softmax(router_logits.float(), dim=-1)
    _, selected = torch.topk(probs, top_k, dim=-1)

    # Fraction of tokens per expert
    expert_mask = F.one_hot(selected, num_experts).float()
    tokens_per_expert = expert_mask.mean(dim=(0, 1))

    # Average probability per expert
    prob_per_expert = probs.mean(dim=0)

    # Balance loss
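    # With perfectly balanced routing (each expert receives 1/num_experts of
    # the assignments and of the probability mass) this evaluates to 1.0.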
    return (tokens_per_expert * prob_per_expert).sum() * num_experts


def z_loss(router_logits: torch.Tensor | None) -> torch.Tensor:
    """Z-loss to prevent router logits from growing too large.

    Following the ST-MoE-style router z-loss (also used in DeepSeek models):
    penalizes large logits to keep the softmax in its "soft" regime where
    gradients flow properly.
    """
    if router_logits is None:
        return torch.tensor(0.0)

    # logsumexp ≈ max(logits), squaring penalizes large values
    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()


class SharedMoEAudioProjector(nn.Module):
    """Shared MoE Audio Projector.

    Combines a shared expert (always-on) with sparse routed experts.
    Uses zero-initialized router for natural specialization learning.

    Config options:
        - num_experts: Number of routed experts (default: 4)
        - num_experts_per_tok: Top-k routing (default: 2)
        - router_aux_loss_coef: Load balancing loss weight (default: 0.01)
        - router_z_loss_coef: Z-loss weight to prevent large logits (default: 0.001)
    """

    def __init__(self, config):
        super().__init__()

        # Temporal downsampling
        self.k = getattr(config, "projector_pool_stride", 4)

        # Dimensions
        self.encoder_dim = config.encoder_dim
        in_dim = self.encoder_dim * self.k
        out_dim = config.llm_dim
        # No expansion - keep hidden dim same as input dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim

        # MoE config
        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)
        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.01)
        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)

        # Layers
        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)

        # Init
        self._init_weights()

    def _init_weights(self):
        with torch.no_grad():
            # LeCun-style scaled init: std = 1/sqrt(fan_in)
            in_dim = self.encoder_dim * self.k
            std = 1.0 / (in_dim ** 0.5)

            # Use a smaller std for the final projection in the shared expert's residual path
            down_proj_std = std / 2.0

            # Shared expert
            nn.init.normal_(self.moe.shared_expert.gate_proj.weight, std=std)
            nn.init.normal_(self.moe.shared_expert.up_proj.weight, std=std)
            nn.init.normal_(self.moe.shared_expert.down_proj.weight, std=down_proj_std)

            # Routed experts - zero init down_proj so they "grow in" from zero
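            # (a zero down_proj makes each routed expert output exactly zero at
            # init, so the block starts out as the shared expert alone)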
            for expert in self.moe.routed_experts.experts:
                nn.init.normal_(expert.gate_proj.weight, std=std)
                nn.init.normal_(expert.up_proj.weight, std=std)
                nn.init.zeros_(expert.down_proj.weight)

            # Router stays zero-initialized

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        batch_size, seq_len, dim = x.size()

        # Dtype
        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Pad for pooling
        if seq_len % self.k:
            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
            if attention_mask is not None:
                attention_mask = F.pad(attention_mask, (0, self.k - seq_len % self.k), value=0)

        # Store pooled attention mask for aux loss
        if attention_mask is not None:
            # Max-pool the attention mask
            pooled_mask = F.max_pool1d(attention_mask.float().unsqueeze(1), self.k, self.k)
            self.last_attention_mask = pooled_mask.squeeze(1).bool()
        else:
            self.last_attention_mask = None

        # Temporal pooling
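        # Reshape [B, T, D] -> [B, T // k, D * k]: each pooled token is the
        # concatenation of k consecutive encoder frames.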
        x = x.view(batch_size, -1, dim * self.k)

        # Forward
        x = self.moe(x)

        return x

    def get_aux_loss(self) -> torch.Tensor:
        """Get auxiliary losses (call after forward).

        Combines:
        - Load balancing loss: encourages balanced expert usage
        - Z-loss: prevents router logits from growing too large
        """
        router_logits = self.moe.last_router_logits
        if router_logits is None:
            return torch.tensor(0.0, device=self.moe.router.weight.device)

        # Retrieve the attention mask stored during the forward pass
        attention_mask = getattr(self, "last_attention_mask", None)

        # If a mask exists, filter the logits to only include un-padded tokens
        if attention_mask is not None:
            flat_mask = attention_mask.view(-1)
            # Ensure the mask is not all False, which would create an empty tensor
            if flat_mask.any():
                active_logits = router_logits[flat_mask]
            else:
                # If the mask is all False, there are no tokens to compute loss on
                return torch.tensor(0.0, device=router_logits.device)
        else:
            active_logits = router_logits

        balance_loss = load_balancing_loss(active_logits, self.num_experts, self.top_k)
        z = z_loss(active_logits)

        return self.aux_loss_coef * balance_loss + self.z_loss_coef * z
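

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The dimensions and the
# SimpleNamespace config below are assumptions for a smoke test; the real
# config object comes from the surrounding training code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    # Assumed dims: encoder_dim=512 audio encoder, llm_dim=1024 LLM hidden size.
    config = SimpleNamespace(encoder_dim=512, llm_dim=1024)
    projector = SharedMoEAudioProjector(config)

    # Fake encoder output: batch of 2, 100 frames, encoder_dim features.
    audio_features = torch.randn(2, 100, config.encoder_dim)
    attention_mask = torch.ones(2, 100)

    projected = projector(audio_features, attention_mask)
    aux = projector.get_aux_loss()

    # 100 frames pooled with stride 4 -> 25 tokens of width llm_dim.
    print(projected.shape)  # torch.Size([2, 25, 1024])
    print(aux.item())       # scalar auxiliary loss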