vectominist committed on
Commit aab2435 · 1 Parent(s): 453c9e0

upload model and code

Files changed (6)
  1. config.json +39 -0
  2. configuration_usad.py +66 -0
  3. model.safetensors +3 -0
  4. modeling_usad.py +19 -0
  5. usad_model.py +207 -0
  6. usad_modules.py +764 -0
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "architectures": [
+     "USADModel"
+   ],
+   "attention_dropout_p": 0.1,
+   "attention_type": "mhsa",
+   "auto_map": {
+     "AutoConfig": "configuration_usad.USADConfig",
+     "AutoModel": "modeling_usad.USADModel"
+   },
+   "conv_dropout_p": 0.1,
+   "conv_expansion_factor": 2,
+   "conv_kernel_size": 31,
+   "conv_pos": true,
+   "conv_pos_depth": 5,
+   "conv_pos_groups": 16,
+   "conv_pos_width": 95,
+   "conv_subsample_channels": 64,
+   "conv_subsample_rate": 2,
+   "encoder_dim": 1024,
+   "feed_forward_dropout_p": 0.1,
+   "feed_forward_expansion_factor": 4,
+   "half_step_residual": true,
+   "input_dim": 128,
+   "input_dropout_p": 0.0,
+   "mamba_bidirectional": false,
+   "mamba_d_conv": 4,
+   "mamba_d_state": 16,
+   "mamba_expand": 2,
+   "model_type": "usad",
+   "num_attention_heads": 16,
+   "num_layers": 24,
+   "subsample_normalization": true,
+   "torch_dtype": "float32",
+   "transformer_style": true,
+   "transformers_version": "4.52.4",
+   "use_framewise_subsample": true,
+   "use_patchwise_subsample": false
+ }
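
The auto_map entries above let the checkpoint load through the standard transformers Auto classes with remote code enabled. A minimal loading sketch (not part of the upload); the repository id below is a placeholder and must be replaced with the actual repo id:

import torch
from transformers import AutoConfig, AutoModel

repo_id = "<user>/<usad-repo>"  # placeholder, not the real repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
print(config.encoder_dim, config.num_layers)  # 1024, 24 for this checkpoint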
configuration_usad.py ADDED
@@ -0,0 +1,66 @@
+ from transformers import PretrainedConfig
+
+
+ class USADConfig(PretrainedConfig):
+     model_type = "usad"
+
+     def __init__(
+         self,
+         encoder_dim: int = 384,
+         num_layers: int = 12,
+         attention_type: str = "mhsa",
+         num_attention_heads: int = 6,
+         mamba_d_state: int = 16,
+         mamba_d_conv: int = 4,
+         mamba_expand: int = 2,
+         mamba_bidirectional: bool = False,
+         feed_forward_expansion_factor: int = 4,
+         conv_expansion_factor: int = 2,
+         feed_forward_dropout_p: float = 0.1,
+         attention_dropout_p: float = 0.1,
+         conv_dropout_p: float = 0.1,
+         conv_kernel_size: int = 31,
+         half_step_residual: bool = True,
+         transformer_style: bool = True,
+         use_framewise_subsample: bool = True,
+         use_patchwise_subsample: bool = False,
+         conv_subsample_channels: int = 64,
+         conv_subsample_rate: int = 2,
+         input_dim: int = 128,
+         input_dropout_p: float = 0.0,
+         conv_pos: bool = True,
+         conv_pos_depth: int = 5,
+         conv_pos_width: int = 95,
+         conv_pos_groups: int = 16,
+         subsample_normalization: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         self.encoder_dim = encoder_dim
+         self.num_layers = num_layers
+         self.attention_type = attention_type
+         self.num_attention_heads = num_attention_heads
+         self.mamba_d_state = mamba_d_state
+         self.mamba_d_conv = mamba_d_conv
+         self.mamba_expand = mamba_expand
+         self.mamba_bidirectional = mamba_bidirectional
+         self.feed_forward_expansion_factor = feed_forward_expansion_factor
+         self.conv_expansion_factor = conv_expansion_factor
+         self.feed_forward_dropout_p = feed_forward_dropout_p
+         self.attention_dropout_p = attention_dropout_p
+         self.conv_dropout_p = conv_dropout_p
+         self.conv_kernel_size = conv_kernel_size
+         self.half_step_residual = half_step_residual
+         self.transformer_style = transformer_style
+         self.use_framewise_subsample = use_framewise_subsample
+         self.use_patchwise_subsample = use_patchwise_subsample
+         self.conv_subsample_channels = conv_subsample_channels
+         self.conv_subsample_rate = conv_subsample_rate
+         self.input_dim = input_dim
+         self.input_dropout_p = input_dropout_p
+         self.conv_pos = conv_pos
+         self.conv_pos_depth = conv_pos_depth
+         self.conv_pos_width = conv_pos_width
+         self.conv_pos_groups = conv_pos_groups
+         self.subsample_normalization = subsample_normalization
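
The defaults above describe a smaller encoder than the uploaded checkpoint; config.json overrides them. A short illustrative sketch (assuming configuration_usad.py is importable from the working directory):

from configuration_usad import USADConfig

base_cfg = USADConfig()  # defaults: encoder_dim=384, num_layers=12, num_attention_heads=6
ckpt_cfg = USADConfig(encoder_dim=1024, num_layers=24, num_attention_heads=16)  # values from config.json
print(base_cfg.encoder_dim, ckpt_cfg.encoder_dim)  # 384 1024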
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5b8f98da245729082692545783647fdcd2164d0b144456249e9f8944e6e5fd6
+ size 1343582744
modeling_usad.py ADDED
@@ -0,0 +1,19 @@
+ # modeling_usad.py
+
+ from transformers import PreTrainedModel
+ from .configuration_usad import USADConfig
+ from .usad_model import UsadModel
+
+
+ class USADModel(PreTrainedModel):
+     config_class = USADConfig
+
+     def __init__(self, config: USADConfig):
+         super().__init__(config)
+         self.model = UsadModel(config)
+
+     def forward(self, *args, **kwargs):
+         return self.model(*args, **kwargs)
+
+     def load_audio(self, audio_path):
+         return self.model.load_audio(audio_path)
usad_model.py ADDED
@@ -0,0 +1,207 @@
+ from dataclasses import make_dataclass
+
+ import torch
+ import torchaudio
+ from torch import nn
+
+ from .usad_modules import ConformerEncoder
+
+ MAX_MEL_LENGTH = 3000  # 30 seconds
+
+
+ @torch.no_grad()
+ def wav_to_fbank(
+     wavs: torch.Tensor,
+     mel_dim: int = 128,
+     norm_mean: float = -4.268,
+     norm_std: float = 4.569,
+ ) -> torch.Tensor:
+     """Convert waveform to fbank features.
+
+     Args:
+         wavs (torch.Tensor): (B, T_wav) waveform tensor.
+         mel_dim (int, optional): mel dimension. Defaults to 128.
+         norm_mean (float, optional):
+             mean for normalization. Defaults to -4.268.
+         norm_std (float, optional):
+             std for normalization. Defaults to 4.569.
+
+     Returns:
+         torch.Tensor: (B, T_mel, mel_dim) fbank features.
+     """
+     # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
+     dtype = wavs.dtype
+     wavs = wavs.to(torch.float32)
+     wavs = wavs - wavs.mean(dim=-1, keepdim=True)
+     feats = [
+         torchaudio.compliance.kaldi.fbank(
+             wavs[i : i + 1],
+             htk_compat=True,
+             sample_frequency=16000,
+             use_energy=False,
+             window_type="hanning",
+             num_mel_bins=mel_dim,
+             dither=0.0,
+             frame_shift=10,
+         ).to(dtype=dtype)
+         for i in range(wavs.shape[0])
+     ]
+
+     mels = torch.stack(feats, dim=0)
+     mels = (mels - norm_mean) / (norm_std * 2)
+
+     return mels
+
+
+ class UsadModel(nn.Module):
+     def __init__(self, cfg) -> None:
+         """Initialize the UsadModel.
+         Args:
+             cfg: Configuration object containing model parameters.
+         """
+         super().__init__()
+
+         self.cfg = cfg
+         self.encoder = ConformerEncoder(cfg)
+         self.max_mel_length = MAX_MEL_LENGTH
+         # NOTE: The max_mel_length is set to 3000,
+         # which corresponds to 30 seconds of audio at 100 Hz frame rate.
+
+     @property
+     def sample_rate(self) -> int:
+         return 16000  # Hz
+
+     @property
+     def encoder_frame_rate(self) -> int:
+         return 50  # Hz
+
+     @property
+     def mel_dim(self) -> int:
+         return self.cfg.input_dim
+
+     @property
+     def encoder_dim(self) -> int:
+         return self.cfg.encoder_dim
+
+     @property
+     def num_layers(self) -> int:
+         return self.cfg.num_layers
+
+     @property
+     def scene_embedding_size(self) -> int:
+         return self.cfg.encoder_dim * self.cfg.num_layers
+
+     @property
+     def timestamp_embedding_size(self) -> int:
+         return self.cfg.encoder_dim * self.cfg.num_layers
+
+     @property
+     def device(self) -> torch.device:
+         """Get the device on which the model is located."""
+         return next(self.parameters()).device
+
+     def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
+         """Set the maximum chunk size for feature extraction.
+
+         Args:
+             seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
+         """
+         assert (
+             seconds >= 0.1
+         ), f"Chunk size must be at least 0.1s, got {seconds} seconds."
+         self.max_mel_length = int(seconds * 100)  # 100 Hz frame rate
+
+     def load_audio(self, audio_path: str) -> torch.Tensor:
+         """Load audio file and return waveform tensor.
+         Args:
+             audio_path (str): Path to the audio file.
+
+         Returns:
+             torch.Tensor: Waveform tensor of shape (wav_len,).
+         """
+
+         waveform, sr = torchaudio.load(audio_path)
+         if sr != self.sample_rate:
+             waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
+         if waveform.shape[0] > 1:
+             # If stereo, convert to mono by averaging channels
+             waveform = waveform.mean(dim=0, keepdim=True)
+
+         waveform = waveform.squeeze(0)  # Remove channel dimension if mono
+         return waveform.to(self.device)  # Ensure tensor is on the same device
+
+     def forward(
+         self,
+         wavs: torch.Tensor,
+         norm_mean: float = -4.268,
+         norm_std: float = 4.569,
+     ) -> dict:
+         """Forward pass for the model.
+
+         Args:
+             wavs (torch.Tensor):
+                 Input waveform tensor of shape (batch_size, wav_len).
+             norm_mean (float, optional):
+                 Mean for normalization. Defaults to -4.268.
+             norm_std (float, optional):
+                 Standard deviation for normalization. Defaults to 4.569.
+
+         Returns:
+             dict: A dictionary containing the model's outputs.
+         """
+         # wavs: (batch_size, wav_len)
+
+         mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
+         mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]
+         if mel.shape[1] <= self.max_mel_length:
+             x, x_len, layer_results = self.encoder(mel, return_hidden=True)
+
+             result = {
+                 "x": x,
+                 "mel": mel,
+                 "hidden_states": layer_results["hidden_states"],
+                 "ffn": layer_results["ffn_1"],
+             }
+             return result
+
+         result = {
+             "x": [],
+             "mel": mel,
+             "hidden_states": [[] for _ in range(self.cfg.num_layers)],
+             "ffn": [[] for _ in range(self.cfg.num_layers)],
+         }
+         for i in range(0, mel.shape[1], self.max_mel_length):
+             if mel.shape[1] - i < 10:
+                 break
+
+             x, x_len, layer_results = self.encoder(
+                 mel[:, i : i + self.max_mel_length], return_hidden=True
+             )
+             result["x"].append(x)
+             for j in range(self.cfg.num_layers):
+                 result["hidden_states"][j].append(layer_results["hidden_states"][j])
+                 result["ffn"][j].append(layer_results["ffn_1"][j])
+
+         result["x"] = torch.cat(result["x"], dim=1)
+         for j in range(self.cfg.num_layers):
+             result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
+             result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
+
+         # result["x"]: model final output (batch_size, seq_len, encoder_dim)
+         # result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
+         # result["hidden_states"]: List of (batch_size, seq_len, encoder_dim)
+         # result["ffn"]: List of (batch_size, seq_len, encoder_dim)
+         return result
+
+     @classmethod
+     def load_from_fairseq_ckpt(cls, ckpt_path: str):
+         checkpoint = torch.load(ckpt_path, weights_only=False)
+         config = checkpoint["cfg"]["model"]
+         config = make_dataclass("Config", config.keys())(**config)
+         model = cls(config)
+         state_dict = checkpoint["model"]
+         for k in list(state_dict.keys()):
+             if not k.startswith("encoder."):
+                 del state_dict[k]
+         model.load_state_dict(state_dict, strict=True)
+         return model
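
For reference, a minimal end-to-end sketch of the feature-extraction API above (not part of the upload). It assumes the three Python files are saved inside a local package directory, e.g. usad/ with an empty __init__.py, so their relative imports resolve; the audio path is a placeholder, and the model here is randomly initialized, whereas in practice the pretrained weights would be loaded via AutoModel as shown earlier:

import torch
from usad.configuration_usad import USADConfig  # hypothetical local package layout
from usad.modeling_usad import USADModel

config = USADConfig(encoder_dim=1024, num_layers=24, num_attention_heads=16)
model = USADModel(config).eval()               # random weights; load the checkpoint for real use

wav = model.load_audio("speech.wav")           # placeholder path; returns mono 16 kHz, shape (wav_len,)
with torch.no_grad():
    out = model(wav.unsqueeze(0))              # forward expects (batch_size, wav_len)

print(out["x"].shape)                          # (1, T, 1024), roughly 50 frames per second
print(len(out["hidden_states"]))               # 24, one tensor per encoder layer

Inputs longer than max_mel_length frames (3000, i.e. 30 s) are processed in chunks and concatenated along time; model.model.set_audio_chunk_size(seconds) adjusts that window.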
usad_modules.py ADDED
@@ -0,0 +1,764 @@
+ # Copyright (c) 2021, Soohwan Kim. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import contextlib
+ import math
+ from collections import defaultdict
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+
+ class SamePad(nn.Module):
+     def __init__(self, kernel_size, causal=False):
+         super().__init__()
+         if causal:
+             self.remove = kernel_size - 1
+         else:
+             self.remove = 1 if kernel_size % 2 == 0 else 0
+
+     def forward(self, x):
+         if self.remove > 0:
+             x = x[:, :, : -self.remove]
+         return x
+
+
+ class TransposeLast(nn.Module):
+     def __init__(self, deconstruct_idx=None, transpose_dim=-2):
+         super().__init__()
+         self.deconstruct_idx = deconstruct_idx
+         self.transpose_dim = transpose_dim
+
+     def forward(self, x):
+         if self.deconstruct_idx is not None:
+             x = x[self.deconstruct_idx]
+         return x.transpose(self.transpose_dim, -1)
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super(Swish, self).__init__()
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return inputs * inputs.sigmoid()
+
+
+ class GLU(nn.Module):
+     def __init__(self, dim: int) -> None:
+         super(GLU, self).__init__()
+         self.dim = dim
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         outputs, gate = inputs.chunk(2, dim=self.dim)
+         return outputs * gate.sigmoid()
+
+
+ class ResidualConnectionModule(nn.Module):
+     def __init__(
+         self,
+         module: nn.Module,
+         module_factor: float = 1.0,
+         input_factor: float = 1.0,
+     ):
+         super(ResidualConnectionModule, self).__init__()
+         self.module = module
+         self.module_factor = module_factor
+         self.input_factor = input_factor
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return (self.module(inputs) * self.module_factor) + (inputs * self.input_factor)
+
+
+ class Linear(nn.Module):
+     def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
+         super(Linear, self).__init__()
+         self.linear = nn.Linear(in_features, out_features, bias=bias)
+         nn.init.xavier_uniform_(self.linear.weight)
+         if bias:
+             nn.init.zeros_(self.linear.bias)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.linear(x)
+
+
+ class View(nn.Module):
+     def __init__(self, shape: tuple, contiguous: bool = False):
+         super(View, self).__init__()
+         self.shape = shape
+         self.contiguous = contiguous
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.contiguous:
+             x = x.contiguous()
+
+         return x.view(*self.shape)
+
+
+ class Transpose(nn.Module):
+     def __init__(self, shape: tuple):
+         super(Transpose, self).__init__()
+         self.shape = shape
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return x.transpose(*self.shape)
+
+
+ class FeedForwardModule(nn.Module):
+     def __init__(
+         self,
+         encoder_dim: int = 512,
+         expansion_factor: int = 4,
+         dropout_p: float = 0.1,
+     ) -> None:
+         super(FeedForwardModule, self).__init__()
+         self.sequential = nn.Sequential(
+             nn.LayerNorm(encoder_dim),
+             Linear(encoder_dim, encoder_dim * expansion_factor, bias=True),
+             Swish(),
+             nn.Dropout(p=dropout_p),
+             Linear(encoder_dim * expansion_factor, encoder_dim, bias=True),
+             nn.Dropout(p=dropout_p),
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.sequential(inputs)
+
+
+ class DepthwiseConv1d(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: int = 0,
+         bias: bool = False,
+     ) -> None:
+         super(DepthwiseConv1d, self).__init__()
+         assert (
+             out_channels % in_channels == 0
+         ), "out_channels should be a constant multiple of in_channels"
+         self.conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             groups=in_channels,
+             stride=stride,
+             padding=padding,
+             bias=bias,
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.conv(inputs)
+
+
+ class PointwiseConv1d(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         stride: int = 1,
+         padding: int = 0,
+         bias: bool = True,
+     ) -> None:
+         super(PointwiseConv1d, self).__init__()
+         self.conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             stride=stride,
+             padding=padding,
+             bias=bias,
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.conv(inputs)
+
+
+ class ConformerConvModule(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         kernel_size: int = 31,
+         expansion_factor: int = 2,
+         dropout_p: float = 0.1,
+     ) -> None:
+         super(ConformerConvModule, self).__init__()
+         assert (
+             kernel_size - 1
+         ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
+         assert expansion_factor == 2, "Currently, only expansion_factor 2 is supported"
+
+         self.sequential = nn.Sequential(
+             nn.LayerNorm(in_channels),
+             Transpose(shape=(1, 2)),
+             PointwiseConv1d(
+                 in_channels,
+                 in_channels * expansion_factor,
+                 stride=1,
+                 padding=0,
+                 bias=True,
+             ),
+             GLU(dim=1),
+             DepthwiseConv1d(
+                 in_channels,
+                 in_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=(kernel_size - 1) // 2,
+             ),
+             nn.BatchNorm1d(in_channels),
+             Swish(),
+             PointwiseConv1d(in_channels, in_channels, stride=1, padding=0, bias=True),
+             nn.Dropout(p=dropout_p),
+         )
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.sequential(inputs).transpose(1, 2)
+
+
+ class FramewiseConv2dSubampling(nn.Module):
+     def __init__(self, out_channels: int, subsample_rate: int = 2) -> None:
+         super(FramewiseConv2dSubampling, self).__init__()
+         assert subsample_rate in {2, 4}, "subsample_rate should be 2 or 4"
+         self.subsample_rate = subsample_rate
+         self.cnn = nn.Sequential(
+             nn.Conv2d(1, out_channels, kernel_size=3, stride=2),
+             nn.ReLU(),
+             nn.Conv2d(
+                 out_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=(2 if subsample_rate == 4 else 1, 2),
+                 padding=(0 if subsample_rate == 4 else 1, 0),
+             ),
+             nn.ReLU(),
+         )
+
+     def forward(
+         self, inputs: torch.Tensor, input_lengths: torch.LongTensor
+     ) -> Tuple[torch.Tensor, torch.LongTensor]:
+         # inputs: (B, T, C) -> (B, 1, T, C)
+         if self.subsample_rate == 2 and inputs.shape[1] % 2 == 0:
+             inputs = F.pad(inputs, (0, 0, 0, 1), "constant", 0)
+         outputs = self.cnn(inputs.unsqueeze(1))
+         batch_size, channels, subsampled_lengths, subsampled_dim = outputs.size()
+
+         outputs = outputs.permute(0, 2, 1, 3)
+         outputs = outputs.contiguous().view(
+             batch_size, subsampled_lengths, channels * subsampled_dim
+         )
+
+         if self.subsample_rate == 4:
+             output_lengths = (((input_lengths - 1) >> 1) - 1) >> 1
+         else:
+             output_lengths = input_lengths >> 1
+
+         return outputs, output_lengths
+
+
+ class PatchwiseConv2dSubampling(nn.Module):
+     def __init__(
+         self,
+         mel_dim: int,
+         out_channels: int,
+         patch_size_time: int = 16,
+         patch_size_freq: int = 16,
+     ) -> None:
+         super(PatchwiseConv2dSubampling, self).__init__()
+
+         self.mel_dim = mel_dim
+         self.patch_size_time = patch_size_time
+         self.patch_size_freq = patch_size_freq
+
+         self.proj = nn.Conv2d(
+             1,
+             out_channels,
+             kernel_size=(patch_size_time, patch_size_freq),
+             stride=(patch_size_time, patch_size_freq),
+             padding=0,
+         )
+         self.cnn = nn.Sequential(
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
+             nn.ReLU(),
+         )
+
+     @property
+     def subsample_rate(self) -> int:
+         return self.patch_size_time * self.patch_size_freq // self.mel_dim
+
+     def forward(
+         self, inputs: torch.Tensor, input_lengths: torch.LongTensor
+     ) -> Tuple[torch.Tensor, torch.LongTensor]:
+         assert (
+             inputs.shape[2] == self.mel_dim
+         ), "inputs.shape[2] should be equal to mel_dim"
+
+         # inputs: (B, Time, Freq) -> (B, 1, Time, Freq)
+         outputs = self.proj(inputs.unsqueeze(1))
+         outputs = self.cnn(outputs)
+         # (B, channels, Time // patch_size_time, Freq // patch_size_freq)
+         outputs = outputs.flatten(2, 3).transpose(1, 2)
+         # (B, (Time // patch_size_time) * (Freq // patch_size_freq), channels)
+
+         output_lengths = (
+             input_lengths
+             // self.patch_size_time
+             * (self.mel_dim // self.patch_size_freq)
+         )
+
+         return outputs, output_lengths
+
+
+ class RelPositionalEncoding(nn.Module):
+     def __init__(self, d_model: int, max_len: int = 10000) -> None:
+         super(RelPositionalEncoding, self).__init__()
+         self.d_model = d_model
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+     def extend_pe(self, x: torch.Tensor) -> None:
+         if self.pe is not None:
+             if self.pe.size(1) >= x.size(1) * 2 - 1:
+                 if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                     self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                 return
+
+         pe_positive = torch.zeros(x.size(1), self.d_model)
+         pe_negative = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32)
+             * -(math.log(10000.0) / self.d_model)
+         )
+         pe_positive[:, 0::2] = torch.sin(position * div_term)
+         pe_positive[:, 1::2] = torch.cos(position * div_term)
+         pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+         pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+         pe_negative = pe_negative[1:].unsqueeze(0)
+         pe = torch.cat([pe_positive, pe_negative], dim=1)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: (B, T, C)
+         self.extend_pe(x)
+         pos_emb = self.pe[
+             :,
+             self.pe.size(1) // 2 - x.size(1) + 1 : self.pe.size(1) // 2 + x.size(1),
+         ]
+         return pos_emb
+
+
+ class RelativeMultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         d_model: int = 512,
+         num_heads: int = 16,
+         dropout_p: float = 0.1,
+     ):
+         super(RelativeMultiHeadAttention, self).__init__()
+         assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+         self.d_model = d_model
+         self.d_head = int(d_model / num_heads)
+         self.num_heads = num_heads
+         self.sqrt_dim = math.sqrt(self.d_head)
+
+         self.query_proj = Linear(d_model, d_model)
+         self.key_proj = Linear(d_model, d_model)
+         self.value_proj = Linear(d_model, d_model)
+         self.pos_proj = Linear(d_model, d_model, bias=False)
+
+         self.dropout = nn.Dropout(p=dropout_p)
+         self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+         self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+         torch.nn.init.xavier_uniform_(self.u_bias)
+         torch.nn.init.xavier_uniform_(self.v_bias)
+
+         self.out_proj = Linear(d_model, d_model)
+
+     def forward(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         pos_embedding: torch.Tensor,
+         mask: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size = value.size(0)
+
+         query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+         key = (
+             self.key_proj(key)
+             .view(batch_size, -1, self.num_heads, self.d_head)
+             .permute(0, 2, 1, 3)
+         )
+         value = (
+             self.value_proj(value)
+             .view(batch_size, -1, self.num_heads, self.d_head)
+             .permute(0, 2, 1, 3)
+         )
+         pos_embedding = self.pos_proj(pos_embedding).view(
+             batch_size, -1, self.num_heads, self.d_head
+         )
+
+         content_score = torch.matmul(
+             (query + self.u_bias).transpose(1, 2), key.transpose(2, 3)
+         )
+         pos_score = torch.matmul(
+             (query + self.v_bias).transpose(1, 2),
+             pos_embedding.permute(0, 2, 3, 1),
+         )
+         pos_score = self._relative_shift(pos_score)
+
+         score = (content_score + pos_score) / self.sqrt_dim
+
+         if mask is not None:
+             mask = mask.unsqueeze(1)
+             score.masked_fill_(mask, -1e9)
+
+         attn = F.softmax(score, -1)
+         attn = self.dropout(attn)
+
+         context = torch.matmul(attn, value).transpose(1, 2)
+         context = context.contiguous().view(batch_size, -1, self.d_model)
+
+         return self.out_proj(context), attn
+
+     def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor:
+         batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+         zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+         padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+         padded_pos_score = padded_pos_score.view(
+             batch_size, num_heads, seq_length2 + 1, seq_length1
+         )
+         pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)[
+             :, :, :, : seq_length2 // 2 + 1
+         ]
+
+         return pos_score
+
+
+ class MultiHeadedSelfAttentionModule(nn.Module):
+     def __init__(self, d_model: int, num_heads: int, dropout_p: float = 0.1):
+         super(MultiHeadedSelfAttentionModule, self).__init__()
+         self.positional_encoding = RelPositionalEncoding(d_model)
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.attention = RelativeMultiHeadAttention(d_model, num_heads, dropout_p)
+         self.dropout = nn.Dropout(p=dropout_p)
+
+     def forward(
+         self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         batch_size = inputs.size(0)
+         pos_embedding = self.positional_encoding(inputs)
+         pos_embedding = pos_embedding.repeat(batch_size, 1, 1)
+
+         inputs = self.layer_norm(inputs)
+         outputs, attn = self.attention(
+             inputs, inputs, inputs, pos_embedding=pos_embedding, mask=mask
+         )
+
+         return self.dropout(outputs), attn
+
+
+ class ConformerBlock(nn.Module):
+     def __init__(
+         self,
+         encoder_dim: int = 512,
+         attention_type: str = "mhsa",
+         num_attention_heads: int = 8,
+         mamba_d_state: int = 16,
+         mamba_d_conv: int = 4,
+         mamba_expand: int = 2,
+         mamba_bidirectional: bool = True,
+         feed_forward_expansion_factor: int = 4,
+         conv_expansion_factor: int = 2,
+         feed_forward_dropout_p: float = 0.1,
+         attention_dropout_p: float = 0.1,
+         conv_dropout_p: float = 0.1,
+         conv_kernel_size: int = 31,
+         half_step_residual: bool = True,
+         transformer_style: bool = False,
+     ):
+         super(ConformerBlock, self).__init__()
+
+         self.transformer_style = transformer_style
+         self.attention_type = attention_type
+
+         if half_step_residual and not transformer_style:
+             self.feed_forward_residual_factor = 0.5
+         else:
+             self.feed_forward_residual_factor = 1
+
+         assert attention_type in ["mhsa", "mamba"]
+         if attention_type == "mhsa":
+             attention = MultiHeadedSelfAttentionModule(
+                 d_model=encoder_dim,
+                 num_heads=num_attention_heads,
+                 dropout_p=attention_dropout_p,
+             )
+
+         self.ffn_1 = FeedForwardModule(
+             encoder_dim=encoder_dim,
+             expansion_factor=feed_forward_expansion_factor,
+             dropout_p=feed_forward_dropout_p,
+         )
+         self.attention = attention
+         if not transformer_style:
+             self.conv = ConformerConvModule(
+                 in_channels=encoder_dim,
+                 kernel_size=conv_kernel_size,
+                 expansion_factor=conv_expansion_factor,
+                 dropout_p=conv_dropout_p,
+             )
+             self.ffn_2 = FeedForwardModule(
+                 encoder_dim=encoder_dim,
+                 expansion_factor=feed_forward_expansion_factor,
+                 dropout_p=feed_forward_dropout_p,
+             )
+         self.layernorm = nn.LayerNorm(encoder_dim)
+
+     def forward(
+         self, x: torch.Tensor
+     ) -> Tuple[torch.Tensor, Dict[str, Union[torch.Tensor, None]]]:
+         # FFN 1
+         ffn_1_out = self.ffn_1(x)
+         x = ffn_1_out * self.feed_forward_residual_factor + x
+
+         # Attention
+         if not isinstance(self.attention, MultiHeadedSelfAttentionModule):
+             # MAMBA
+             attn_out = self.attention(x)
+             attn = None
+         else:
+             attn_out, attn = self.attention(x)
+         x = attn_out + x
+
+         if self.transformer_style:
+             x = self.layernorm(x)
+             return x, {
+                 "ffn_1": ffn_1_out,
+                 "attn": attn,
+                 "conv": None,
+                 "ffn_2": None,
+             }
+
+         # Convolution
+         conv_out = self.conv(x)
+         x = conv_out + x
+
+         # FFN 2
+         ffn_2_out = self.ffn_2(x)
+         x = ffn_2_out * self.feed_forward_residual_factor + x
+         x = self.layernorm(x)
+
+         other = {
+             "ffn_1": ffn_1_out,
+             "attn": attn,
+             "conv": conv_out,
+             "ffn_2": ffn_2_out,
+         }
+
+         return x, other
+
+
+ class ConformerEncoder(nn.Module):
+     def __init__(self, cfg):
+         super(ConformerEncoder, self).__init__()
+
+         self.cfg = cfg
+         self.framewise_subsample = None
+         self.patchwise_subsample = None
+         self.framewise_in_proj = None
+         self.patchwise_in_proj = None
+         assert (
+             cfg.use_framewise_subsample or cfg.use_patchwise_subsample
+         ), "At least one subsampling method should be used"
+         if cfg.use_framewise_subsample:
+             self.framewise_subsample = FramewiseConv2dSubampling(
+                 out_channels=cfg.conv_subsample_channels,
+                 subsample_rate=cfg.conv_subsample_rate,
+             )
+             self.framewise_in_proj = nn.Sequential(
+                 Linear(
+                     cfg.conv_subsample_channels * (((cfg.input_dim - 1) // 2 - 1) // 2),
+                     cfg.encoder_dim,
+                 ),
+                 nn.Dropout(p=cfg.input_dropout_p),
+             )
+         if cfg.use_patchwise_subsample:
+             self.patchwise_subsample = PatchwiseConv2dSubampling(
+                 mel_dim=cfg.input_dim,
+                 out_channels=cfg.conv_subsample_channels,
+                 patch_size_time=cfg.patch_size_time,
+                 patch_size_freq=cfg.patch_size_freq,
+             )
+             self.patchwise_in_proj = nn.Sequential(
+                 Linear(
+                     cfg.conv_subsample_channels,
+                     cfg.encoder_dim,
+                 ),
+                 nn.Dropout(p=cfg.input_dropout_p),
+             )
+             assert not cfg.use_framewise_subsample or (
+                 cfg.conv_subsample_rate == self.patchwise_subsample.subsample_rate
+             ), (
+                 f"conv_subsample_rate ({cfg.conv_subsample_rate}) != patchwise_subsample.subsample_rate"
+                 f" ({self.patchwise_subsample.subsample_rate})"
+             )
+
+         self.framewise_norm, self.patchwise_norm = None, None
+         if getattr(cfg, "subsample_normalization", False):
+             if cfg.use_framewise_subsample:
+                 self.framewise_norm = nn.LayerNorm(cfg.encoder_dim)
+             if cfg.use_patchwise_subsample:
+                 self.patchwise_norm = nn.LayerNorm(cfg.encoder_dim)
+
+         self.conv_pos = None
+         if getattr(cfg, "conv_pos", False):
+             num_pos_layers = cfg.conv_pos_depth
+             k = max(3, cfg.conv_pos_width // num_pos_layers)
+             self.conv_pos = nn.Sequential(
+                 TransposeLast(),
+                 *[
+                     nn.Sequential(
+                         nn.Conv1d(
+                             cfg.encoder_dim,
+                             cfg.encoder_dim,
+                             kernel_size=k,
+                             padding=k // 2,
+                             groups=cfg.conv_pos_groups,
+                         ),
+                         SamePad(k),
+                         TransposeLast(),
+                         nn.LayerNorm(cfg.encoder_dim, elementwise_affine=False),
+                         TransposeLast(),
+                         nn.GELU(),
+                     )
+                     for _ in range(num_pos_layers)
+                 ],
+                 TransposeLast(),
+             )
+             self.conv_pos_post_ln = nn.LayerNorm(cfg.encoder_dim)
+
+         self.layers = nn.ModuleList(
+             [
+                 ConformerBlock(
+                     encoder_dim=cfg.encoder_dim,
+                     attention_type=cfg.attention_type,
+                     num_attention_heads=cfg.num_attention_heads,
+                     mamba_d_state=cfg.mamba_d_state,
+                     mamba_d_conv=cfg.mamba_d_conv,
+                     mamba_expand=cfg.mamba_expand,
+                     mamba_bidirectional=cfg.mamba_bidirectional,
+                     feed_forward_expansion_factor=cfg.feed_forward_expansion_factor,
+                     conv_expansion_factor=cfg.conv_expansion_factor,
+                     feed_forward_dropout_p=cfg.feed_forward_dropout_p,
+                     attention_dropout_p=cfg.attention_dropout_p,
+                     conv_dropout_p=cfg.conv_dropout_p,
+                     conv_kernel_size=cfg.conv_kernel_size,
+                     half_step_residual=cfg.half_step_residual,
+                     transformer_style=getattr(cfg, "transformer_style", False),
+                 )
+                 for _ in range(cfg.num_layers)
+             ]
+         )
+
+     def count_parameters(self) -> int:
+         """Count parameters of encoder"""
+         return sum([p.numel() for p in self.parameters() if p.requires_grad])
+
+     def update_dropout(self, dropout_p: float) -> None:
+         """Update dropout probability of encoder"""
+         for name, child in self.named_children():
+             if isinstance(child, nn.Dropout):
+                 child.p = dropout_p
+
+     def forward(
+         self,
+         inputs: torch.Tensor,
+         input_lengths: Optional[torch.Tensor] = None,
+         return_hidden: bool = False,
+         freeze_input_layers: bool = False,
+         target_layer: Optional[int] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, List[torch.Tensor]]]:
+         if input_lengths is None:
+             input_lengths = torch.full(
+                 (inputs.size(0),),
+                 inputs.size(1),
+                 dtype=torch.long,
+                 device=inputs.device,
+             )
+
+         with torch.no_grad() if freeze_input_layers else contextlib.ExitStack():
+             frame_feat, patch_feat = None, None
+             if self.framewise_subsample is not None:
+                 frame_feat, frame_lengths = self.framewise_subsample(
+                     inputs, input_lengths
+                 )
+                 frame_feat = self.framewise_in_proj(frame_feat)
+                 if self.framewise_norm is not None:
+                     frame_feat = self.framewise_norm(frame_feat)
+
+             if self.patchwise_subsample is not None:
+                 patch_feat, patch_lengths = self.patchwise_subsample(
+                     inputs, input_lengths
+                 )
+                 patch_feat = self.patchwise_in_proj(patch_feat)
+                 if self.patchwise_norm is not None:
+                     patch_feat = self.patchwise_norm(patch_feat)
+
+             if frame_feat is not None and patch_feat is not None:
+                 min_len = min(frame_feat.size(1), patch_feat.size(1))
+                 frame_feat = frame_feat[:, :min_len]
+                 patch_feat = patch_feat[:, :min_len]
+
+                 features = frame_feat + patch_feat
+                 output_lengths = (
+                     frame_lengths
+                     if frame_lengths.max().item() < patch_lengths.max().item()
+                     else patch_lengths
+                 )
+             elif frame_feat is not None:
+                 features = frame_feat
+                 output_lengths = frame_lengths
+             else:
+                 features = patch_feat
+                 output_lengths = patch_lengths
+
+             if self.conv_pos is not None:
+                 features = features + self.conv_pos(features)
+                 features = self.conv_pos_post_ln(features)
+
+         layer_results = defaultdict(list)
+
+         outputs = features
+         for i, layer in enumerate(self.layers):
+             outputs, other = layer(outputs)
+             if return_hidden:
+                 layer_results["hidden_states"].append(outputs)
+                 for k, v in other.items():
+                     layer_results[k].append(v)
+
+             if target_layer is not None and i == target_layer:
+                 break
+
+         return outputs, output_lengths, layer_results
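
For a rough standalone sanity check of the encoder above (a sketch, not part of the upload; it assumes configuration_usad.py and usad_modules.py sit in the working directory), a small ConformerEncoder shows the 2x frame-wise subsampling and the per-layer outputs collected in layer_results:

import torch
from configuration_usad import USADConfig
from usad_modules import ConformerEncoder

cfg = USADConfig(encoder_dim=64, num_layers=2, num_attention_heads=4)  # small test config
encoder = ConformerEncoder(cfg).eval()

mel = torch.randn(2, 200, cfg.input_dim)      # (batch, frames at 100 Hz, 128 mel bins)
with torch.no_grad():
    x, lengths, layer_results = encoder(mel, return_hidden=True)

print(x.shape)                                # torch.Size([2, 100, 64]): time axis halved to 50 Hz
print(len(layer_results["hidden_states"]))    # 2, one tensor per ConformerBlock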