"""Residual MLP projector for Whisper → LLM feature space translation.

Philosophy: Whisper features are already information-complete. The projector
learns a nonlinear correction/refinement to align them with the LLM's expected
input distribution, rather than replacing them entirely.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812


class ResidualMLP(nn.Module):
    """MLP block with residual connection.

    Output = x + MLP(x)

    With the output projection (fc2) initialized near zero, output ≈ input,
    providing a stable starting point. The network then learns to add
    nonlinear corrections as needed.
    """

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return residual + x


class ResidualAudioProjector(nn.Module):
    """Residual MLP projector for audio-to-LLM feature translation.

    Architecture:
        1. Temporal pooling (concatenate k consecutive frames)
        2. Linear projection to LLM dimension
        3. N residual MLP blocks for nonlinear refinement
        4. Final layer norm

    The linear projection handles dimension matching, while residual MLPs
    learn the nonlinear corrections needed to align acoustic features
    with semantic embedding space.
    """

    def __init__(self, config):
        super().__init__()

        # Temporal downsampling factor
        self.k = getattr(config, "projector_pool_stride", 4)

        # Dimensions
        in_dim = config.encoder_dim * self.k  # After concatenating k frames
        out_dim = config.llm_dim
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim * 4

        # Number of residual blocks
        self.num_layers = getattr(config, "projector_num_layers", 2)

        dropout_rate = getattr(config, "projector_dropout", 0.0)

        # Imported lazily so this module does not require transformers at
        # import time; LlamaRMSNorm matches the LLM's own normalization layer.
        from transformers.models.llama.modeling_llama import LlamaRMSNorm

        # Initial projection: encoder_dim * k → llm_dim
        self.input_proj = nn.Linear(in_dim, out_dim)
        self.ln_input = LlamaRMSNorm(out_dim, eps=1e-6)

        # Residual MLP blocks for nonlinear refinement
        self.layers = nn.ModuleList([
            ResidualMLP(out_dim, hidden_dim, dropout=dropout_rate)
            for _ in range(self.num_layers)
        ])

        # Per-layer norms (applied after each residual block)
        self.layer_norms = nn.ModuleList([
            LlamaRMSNorm(out_dim, eps=1e-6)
            for _ in range(self.num_layers)
        ])

        self.output_dropout = nn.Dropout(dropout_rate)

        # Initialize for stable training
        self._init_weights(config)

    def _init_weights(self, config):
        """Initialize weights for stable residual learning.

        Key insight: Initialize fc2 of each residual block to near-zero
        so that initially output ≈ input (identity function).
        """
        std = getattr(config, "projector_init_std", 0.02)

        with torch.no_grad():
            # Input projection: standard init
            nn.init.normal_(self.input_proj.weight, mean=0.0, std=std)
            if self.input_proj.bias is not None:
                nn.init.zeros_(self.input_proj.bias)

            # Layer norms
            self.ln_input.weight.fill_(1.0)
            for ln in self.layer_norms:
                ln.weight.fill_(1.0)

            # Residual blocks: small init on output projection
            for layer in self.layers:
                nn.init.normal_(layer.fc1.weight, mean=0.0, std=std)
                # Initialize fc2 smaller so residual starts near identity
                nn.init.normal_(layer.fc2.weight, mean=0.0, std=std * 0.1)
                if layer.fc1.bias is not None:
                    nn.init.zeros_(layer.fc1.bias)
                if layer.fc2.bias is not None:
                    nn.init.zeros_(layer.fc2.bias)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, encoder_dim] from Whisper encoder

        Returns:
            [batch_size, ceil(seq_len / k), llm_dim] projected features
            (the sequence is right-padded to a multiple of k before pooling)
        """
        batch_size, seq_len, dim = x.size()

        # Ensure correct dtype
        target_dtype = self.input_proj.weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Pad sequence to be divisible by k
        remainder = seq_len % self.k
        if remainder:
            pad_len = self.k - remainder
            x = F.pad(x, (0, 0, 0, pad_len))

        # Temporal pooling: concatenate k consecutive frames
        # [B, T, D] → [B, T//k, D*k]
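        # (row-major layout lets a plain view group frames t .. t+k-1 into a
        #  single feature vector; no permute is needed)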
        x = x.contiguous().view(batch_size, -1, dim * self.k)

        # Project to LLM dimension
        x = self.input_proj(x)
        x = self.ln_input(x)

        # Apply residual MLP blocks
        for layer, ln in zip(self.layers, self.layer_norms):
            x = layer(x)
            x = ln(x)

        return self.output_dropout(x)
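

if __name__ == "__main__":
    # Minimal smoke test. The config is a hypothetical stand-in exposing the
    # attributes the projector reads; encoder_dim=1280 matches Whisper-large
    # and llm_dim=4096 is a Llama-style hidden size (illustrative values,
    # not requirements).
    from types import SimpleNamespace

    config = SimpleNamespace(
        encoder_dim=1280,
        llm_dim=4096,
        projector_pool_stride=4,
        projector_num_layers=2,
        projector_dropout=0.0,
        projector_init_std=0.02,
    )
    projector = ResidualAudioProjector(config).eval()

    # 1500 encoder frames ≈ 30 s of audio at Whisper's 50 Hz frame rate;
    # pooling with stride 4 yields 375 output positions.
    features = torch.randn(2, 1500, config.encoder_dim)
    with torch.no_grad():
        out = projector(features)
    print(out.shape)  # torch.Size([2, 375, 4096])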