automodel_remote_code_support #2
by MCplayer - opened
README.md CHANGED
@@ -34,7 +34,7 @@ import torchaudio
  from transformers import AutoFeatureExtractor, AutoModel
 
  # 1. Load the feature extractor and the codec model
- model_id = "fnlp/XY_Tokenizer_TTSD_V0_hf"
+ model_id = "OpenMOSS-Team/XY_Tokenizer_TTSD_V0_hf"
  feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, trust_remote_code=True)
  codec = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().to("cuda")
 
@@ -48,7 +48,6 @@ if sampling_rate != 16000:
  input_features = feature_extractor(wav_form, sampling_rate=16000, return_attention_mask=True, return_tensors="pt")
  # The 'code' dictionary contains the discrete audio codes
  code = codec.encode(input_features)
- print(code)
 
  # 4. Decode the codes back to an audio waveform
  # The output is high-quality 24kHz audio.
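Note: the second hunk's context (`if sampling_rate != 16000:`) keeps the README's assumption that the loaded clip is resampled to 16 kHz before it reaches the feature extractor. A minimal sketch of that preprocessing step, assuming a hypothetical input file `example.wav` and the `torchaudio` import already shown in the README:

```python
import torchaudio

# torchaudio.load returns (waveform, sampling_rate)
wav_form, sampling_rate = torchaudio.load("example.wav")

# Resample to the 16 kHz input rate the codec's feature extractor expects
if sampling_rate != 16000:
    wav_form = torchaudio.functional.resample(wav_form, orig_freq=sampling_rate, new_freq=16000)
    sampling_rate = 16000
```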
config.json CHANGED
@@ -1,5 +1,10 @@
  {
    "model_type": "xy_tokenizer",
+   "auto_map": {
+     "AutoFeatureExtractor": "feature_extraction_xy_tokenizer.XYTokenizerFeatureExtractor",
+     "AutoConfig": "configuration_xy_tokenizer.XYTokenizerConfig",
+     "AutoModel": "modeling_xy_tokenizer.XYTokenizerModel"
+   },
    "input_sample_rate": 16000,
    "output_sample_rate": 24000,
    "encoder_downsample_rate": 1280,
configuration_xy_tokenizer.py ADDED
@@ -0,0 +1,82 @@
+ # coding=utf-8
+ # Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """XYTokenizer model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ logger = logging.get_logger(__name__)
+
+
+ class XYTokenizerConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`XYTokenizerModel`]. It is used to instantiate a
+     XY Tokenizer model according to the specified arguments, defining the model architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Args:
+         input_sample_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate of the input audio.
+         output_sample_rate (`int`, *optional*, defaults to 16000):
+             The sampling rate of the output audio.
+         encoder_downsample_rate (`int`, *optional*, defaults to 1280):
+             The total downsampling factor of the encoder part.
+         decoder_upsample_rate (`int`, *optional*, defaults to 1920):
+             The total upsampling factor of the decoder part.
+         code_dim (`int`, *optional*, defaults to 1280):
+             The dimension of the code embeddings.
+
+     // ... (All other parameters from the original YAML/dict config would be listed here) ...
+     // For brevity, we will define them with default values based on the provided code.
+
+     Example:
+         semantic_encoder_d_model (`int`, *optional*, defaults to 1280):
+             Hidden dimension for the semantic encoder.
+         num_quantizers (`int`, *optional*, defaults to 32):
+             Number of residual quantizers.
+         ...
+     """
+     model_type = "xy_tokenizer"
+
+     # A comprehensive config would flatten all nested kwargs from the original `generator_params`.
+     # For this example, we will create a simplified version. A real implementation would
+     # have all parameters explicitly defined here.
+     def __init__(
+         self,
+         input_sample_rate=16000,
+         output_sample_rate=16000,
+         encoder_downsample_rate=1280,
+         decoder_upsample_rate=1920,
+         code_dim=1280,
+         # A real config would have dozens of parameters here.
+         # We will dynamically accept them via **kwargs.
+         **kwargs,
+     ):
+         self.input_sample_rate = input_sample_rate
+         self.output_sample_rate = output_sample_rate
+         self.encoder_downsample_rate = encoder_downsample_rate
+         self.decoder_upsample_rate = decoder_upsample_rate
+         self.code_dim = code_dim
+
+         # Store all other parameters dynamically. This is a shortcut.
+         # A production-ready config should list all parameters explicitly.
+         self.params = kwargs
+
+         super().__init__(**kwargs)
+
+
+ __all__ = ["XYTokenizerConfig"]
feature_extraction_xy_tokenizer.py ADDED
@@ -0,0 +1,265 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Feature extractor class for XY Tokenizer (built on WhisperFeatureExtractor)
17
+ """
18
+ import math
19
+ from functools import partial
20
+ from typing import List, Optional, Union
21
+ from collections import deque
22
+
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import WhisperFeatureExtractor
26
+ from transformers.audio_utils import mel_filter_bank
27
+ from transformers.configuration_utils import PretrainedConfig
28
+ from transformers.feature_extraction_utils import BatchFeature
29
+ from transformers.utils import TensorType, logging
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ class ExtractorIterator:
35
+ def __init__(
36
+ self,
37
+ data,
38
+ batch_size=8,
39
+ chunk_length=30,
40
+ overlap_seconds=10,
41
+ overlap_side="both",
42
+ sampling_rate=16000,
43
+ encode_func = None,
44
+ ) -> None:
45
+ self.data = data
46
+ self.batch_size = batch_size
47
+ self.chunk_length = chunk_length
48
+ self.overlap_seconds = overlap_seconds
49
+ self.overlap_side = overlap_side
50
+ self.sampling_rate = sampling_rate
51
+
52
+ # duration_size is the effective (non-overlapping) audio length processed per step
53
+ self.chunk_size = int(self.chunk_length * self.sampling_rate)
54
+ self.overlap_size = int(self.overlap_seconds * self.sampling_rate)
55
+ self.duration_size = self.chunk_size - self.overlap_size
56
+ assert (
57
+ (overlap_side == "right") or (self.overlap_size % 2 == 0)
58
+ ), '`overlap_seconds` must be divisible by 2 when `overlap_side` is "both".'
59
+ # Note: only non-overlapping chunks are handled here; overlap is handled externally (if needed)
+ # or more explicitly inside the iterator. For simplicity, chunking is assumed to be based on duration_size.
61
+
62
+ assert callable(encode_func)
63
+ self.encode_func = encode_func
64
+
65
+ def __iter__(self):
66
+ """
67
+ Return a generator that handles all of the batching logic.
+ This is the most Pythonic way to implement it.
69
+ """
70
+ # Batching state is kept in local variables of __iter__, which keeps things clear
71
+ batch_num = 0
72
+
73
+ # Note: chunk_and_pad_view emits windows of chunk_size samples, each advancing by duration_size
74
+ wav_tensor = torch.zeros(self.batch_size, 1, self.chunk_size)
75
+ input_lengths = deque(maxlen=self.batch_size)
76
+ input_seq_no = torch.zeros(self.batch_size, dtype=torch.long)
77
+
78
+ right_boundary = self.get_right_boundary()
79
+
80
+ for i, sample in enumerate(self.data):
81
+ sample_chunks, sample_lengths, sample_seq_no = self.chunk_and_pad_view(sample, i)
82
+
83
+ processed_in_sample = 0
84
+ while processed_in_sample < len(sample_chunks):
85
+ space_in_batch = self.batch_size - batch_num
86
+ chunks_to_add = min(space_in_batch, len(sample_chunks) - processed_in_sample)
87
+
88
+ # Define the slice ranges
89
+ start_idx_sample = processed_in_sample
90
+ end_idx_sample = processed_in_sample + chunks_to_add
91
+ start_idx_batch = batch_num
92
+ end_idx_batch = batch_num + chunks_to_add
93
+
94
+ # Fill in the data
95
+ wav_tensor[start_idx_batch:end_idx_batch] = sample_chunks[start_idx_sample:end_idx_sample]
96
+ input_lengths.extend(sample_lengths[start_idx_sample:end_idx_sample])
97
+ input_seq_no[start_idx_batch:end_idx_batch] = sample_seq_no[start_idx_sample:end_idx_sample]
98
+
99
+ # Update the counters
100
+ batch_num += chunks_to_add
101
+ processed_in_sample += chunks_to_add
102
+
103
+ # If the batch is full, yield a copy and reset
104
+ if batch_num == self.batch_size:
105
+ list_x = []
106
+ for xi, (_, right) in enumerate(input_lengths):
107
+ if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
108
+ list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
109
+ else:
110
+ list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
111
+
112
+ yield BatchFeature({
113
+ **self.encode_func(list_x),
114
+ "input_lengths": input_lengths,
115
+ "chunk_seq_no": input_seq_no.clone(),
116
+ })
117
+
118
+ # Reset the batch counter and the tensor contents
119
+ batch_num = 0
120
+ wav_tensor.zero_()
121
+ input_lengths.clear()
122
+ input_seq_no.zero_()
123
+
124
+ # After the loop, handle the last, possibly partial batch
125
+ if batch_num > 0:
126
+ list_x = []
127
+ for xi in range(batch_num):
128
+ _, right = input_lengths[xi]
129
+ if right == right_boundary and torch.any(wav_tensor[xi, :, right:] != 0):
130
+ list_x.append(wav_tensor[xi].reshape(-1).cpu().numpy())
131
+ else:
132
+ list_x.append(wav_tensor[xi, :, :right].reshape(-1).cpu().numpy())
133
+ yield BatchFeature({
134
+ **self.encode_func(list_x),
135
+ "input_lengths": input_lengths,
136
+ "chunk_seq_no": input_seq_no[:batch_num].clone(),
137
+ })
138
+
139
+ def chunk_and_pad_view(self, tensor, seq_no):
140
+ x = tensor[0:1, :].unsqueeze(0)
141
+
142
+ stride = self.duration_size
143
+ kernel = self.chunk_size
144
+ B, C, L = x.shape
145
+
146
+ num_chunks = max(0, math.ceil((L - kernel) / stride)) + 1
147
+ target_len = (num_chunks - 1) * stride + kernel
148
+ padding_size = max(0, target_len - L)
149
+ x_padded = F.pad(x, (0, padding_size), "constant", 0)
150
+ output_tensor = x_padded.unfold(dimension=2, size=kernel, step=stride).squeeze(0).transpose(0, 1)
151
+
152
+ output_lengths = self.get_windows_boundaries(num_chunks, L)
153
+ output_seq_no = torch.full((num_chunks,), seq_no, dtype=torch.long)
154
+ return output_tensor, output_lengths, output_seq_no
155
+
156
+ def get_left_boundary(self):
157
+ if self.overlap_side == "right":
158
+ return 0
159
+ else:
160
+ return int(self.overlap_size / 2)
161
+
162
+ def get_right_boundary(self):
163
+ if self.overlap_side == "right":
164
+ return self.duration_size
165
+ else:
166
+ return self.chunk_size - int(self.overlap_size / 2)
167
+
168
+ def get_windows_boundaries(self, num_chunks, seq_len):
169
+ left_boundary = self.get_left_boundary()
170
+ right_boundary = self.get_right_boundary()
171
+
172
+ output_lengths = [(left_boundary, right_boundary) for _ in range(num_chunks)]
173
+ output_lengths[0] = (0, output_lengths[0][1])
174
+ output_lengths[-1] = (output_lengths[-1][0], seq_len - self.duration_size * (num_chunks-1))
175
+ return output_lengths
176
+
177
+
178
+ class XYTokenizerFeatureExtractor(WhisperFeatureExtractor):
179
+ def __init__(
180
+ self,
181
+ feature_size=80,
182
+ sampling_rate=16000,
183
+ hop_length=160,
184
+ chunk_length=30,
185
+ n_fft=400,
186
+ n_samples=480000,
187
+ nb_max_frames=3000,
188
+ padding_side="right",
189
+ padding_value=0.0,
190
+ dither=0.0,
191
+ return_attention_mask=False,
192
+ max_frequency=None,
193
+ batch_size=8,
194
+ overlap_side="both",
195
+ **kwargs,
196
+ ):
197
+ super().__init__(
198
+ feature_size=feature_size,
199
+ sampling_rate=sampling_rate,
200
+ hop_length=hop_length,
201
+ chunk_length=chunk_length,
202
+ n_fft=n_fft,
203
+ padding_value=padding_value,
204
+ dither=dither,
205
+ return_attention_mask=return_attention_mask,
206
+ n_samples=n_samples,
207
+ nb_max_frames=nb_max_frames,
208
+ padding_side=padding_side,
209
+ **kwargs,
210
+ )
211
+ self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
212
+ self.batch_size = batch_size
213
+ self.mel_filters = mel_filter_bank(
214
+ num_frequency_bins=1 + n_fft // 2,
215
+ num_mel_filters=feature_size,
216
+ min_frequency=0.0,
217
+ max_frequency=self.max_frequency,
218
+ sampling_rate=sampling_rate,
219
+ norm="slaney",
220
+ mel_scale="slaney",
221
+ )
222
+ self.overlap_side = overlap_side
223
+
224
+ def __call__(
225
+ self,
226
+ raw_speech: Union[torch.Tensor, List[torch.Tensor]],
227
+ truncation: bool = True,
228
+ pad_to_multiple_of: Optional[int] = None,
229
+ return_tensors: Optional[Union[str, TensorType]] = None,
230
+ return_attention_mask: Optional[bool] = None,
231
+ padding: Optional[str] = "max_length",
232
+ max_length: Optional[int] = None,
233
+ sampling_rate: Optional[int] = None,
234
+ do_normalize: Optional[bool] = None,
235
+ device: Optional[str] = "cpu",
236
+ return_token_timestamps: Optional[bool] = None,
237
+ overlap_seconds: int = 10,
238
+ **kwargs,
239
+ ) -> ExtractorIterator:
240
+
241
+ if not isinstance(raw_speech, list):
242
+ raw_speech = [raw_speech]
243
+
244
+ return ExtractorIterator(
245
+ raw_speech,
246
+ batch_size=self.batch_size if self.batch_size else len(raw_speech),
247
+ chunk_length=self.chunk_length,
248
+ overlap_seconds=overlap_seconds,
249
+ overlap_side=self.overlap_side,
250
+ sampling_rate=self.sampling_rate,
251
+ encode_func=partial(
252
+ super().__call__,
253
+ truncation=truncation,
254
+ pad_to_multiple_of=pad_to_multiple_of,
255
+ return_tensors=return_tensors,
256
+ return_attention_mask=return_attention_mask,
257
+ padding=padding,
258
+ max_length=max_length,
259
+ sampling_rate=sampling_rate,
260
+ do_normalize=do_normalize,
261
+ device=device,
262
+ return_token_timestamps=return_token_timestamps,
263
+ **kwargs,
264
+ )
265
+ )
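The iterator above turns long audio into fixed-size windows: each window spans `chunk_length` seconds and consecutive windows advance by `chunk_length - overlap_seconds`. A small worked example of the sizes involved, assuming the defaults used here (30 s windows, 10 s overlap, 16 kHz input):

```python
import math

sampling_rate = 16000
chunk_length = 30       # seconds per window
overlap_seconds = 10    # seconds shared between adjacent windows

chunk_size = chunk_length * sampling_rate        # 480_000 samples per window
overlap_size = overlap_seconds * sampling_rate   # 160_000 overlapping samples
duration_size = chunk_size - overlap_size        # 320_000 samples of stride per window

# Number of windows for a 100-second clip, mirroring chunk_and_pad_view:
num_samples = 100 * sampling_rate
num_chunks = max(0, math.ceil((num_samples - chunk_size) / duration_size)) + 1
print(chunk_size, duration_size, num_chunks)     # 480000 320000 5
```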
modeling_xy_tokenizer.py ADDED
@@ -0,0 +1,1243 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Transformers XYTokenizer model."""
16
+
17
+ import math
18
+ from collections import defaultdict
19
+ from dataclasses import asdict, dataclass
20
+ from typing import Optional, Tuple, Union, List
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from einops import rearrange
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ from transformers.activations import ACT2FN
30
+ from transformers.modeling_utils import PreTrainedAudioTokenizerBase
31
+ from transformers.utils import ModelOutput, logging
32
+ from transformers.feature_extraction_utils import BatchFeature
33
+
34
+ from .configuration_xy_tokenizer import XYTokenizerConfig
35
+ from .feature_extraction_xy_tokenizer import ExtractorIterator
36
+
37
+ logger = logging.get_logger(__name__)
38
+ # ----------------------------------------------- #
39
+ # Model Output Dataclasses #
40
+ # ----------------------------------------------- #
41
+ @dataclass
42
+ class XYTokenizerEncodeOutput(ModelOutput):
43
+ """
44
+ Output type of [`XYTokenizerModel.encode`].
45
+
46
+ Args:
47
+ quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
48
+ The quantized continuous representation of the input audio. This is the output of the quantizer.
49
+ audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
50
+ The discrete codes from the quantizer for each codebook.
51
+ codes_lengths (`torch.LongTensor` of shape `(batch_size,)`):
52
+ The valid length of each sequence in `audio_codes`.
53
+ commit_loss (`torch.FloatTensor`, *optional*):
54
+ The commitment loss from the vector quantizer.
55
+ overlap_seconds (`int`, *optional*):
56
+ The duration of the overlap in seconds between adjacent audio chunks.
57
+ """
58
+ quantized_representation: torch.FloatTensor = None
59
+ audio_codes: torch.LongTensor = None
60
+ codes_lengths: torch.LongTensor = None
61
+ commit_loss: Optional[torch.FloatTensor] = None
62
+ overlap_seconds: Optional[int] = None
63
+
64
+
65
+ @dataclass
66
+ class XYTokenizerDecodeOutput(ModelOutput):
67
+ """
68
+ Output type of [`XYTokenizerModel.decode`].
69
+
70
+ Args:
71
+ audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)`):
72
+ The reconstructed audio waveform.
73
+ output_length (`torch.LongTensor` of shape `(batch_size,)`):
74
+ The valid length of each sequence in `audio_values`.
75
+ """
76
+ audio_values: torch.FloatTensor = None
77
+ output_length: Optional[torch.LongTensor] = None
78
+
79
+
80
+ @dataclass
81
+ class XYTokenizerModelOutput(ModelOutput):
82
+ """
83
+ Output type of [`XYTokenizerModel`]'s forward pass.
84
+
85
+ Args:
86
+ audio_values (`torch.FloatTensor` of shape `(batch_size, 1, sequence_length)`):
87
+ The reconstructed audio waveform.
88
+ output_length (`torch.LongTensor` of shape `(batch_size,)`):
89
+ The valid length of each sequence in `audio_values`.
90
+ quantized_representation (`torch.FloatTensor` of shape `(batch_size, hidden_dim, sequence_length)`):
91
+ The quantized continuous representation of the input audio. This is the output of the quantizer.
92
+ audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
93
+ The discrete codes from the quantizer for each codebook.
94
+ codes_lengths (`torch.LongTensor` of shape `(batch_size,)`):
95
+ The valid length of each sequence in `audio_codes`.
96
+ commit_loss (`torch.FloatTensor`, *optional*):
97
+ The commitment loss from the vector quantizer.
98
+ """
99
+ audio_values: torch.FloatTensor = None
100
+ output_length: torch.LongTensor = None
101
+ quantized_representation: torch.FloatTensor = None
102
+ audio_codes: torch.LongTensor = None
103
+ codes_lengths: torch.LongTensor = None
104
+ commit_loss: Optional[torch.FloatTensor] = None
105
+
106
+
107
+ @dataclass
108
+ class VectorQuantizerConfig:
109
+ """Configuration for the VectorQuantize module."""
110
+ commitment: float = 1.0
111
+ decay: float = 0.99
112
+ epsilon: float = 1e-5
113
+ threshold_ema_dead: int = 2
114
+ kmeans_init: bool = True
115
+ kmeans_iters: int = 10
116
+
117
+
118
+ # ----------------------------------------------- #
119
+ # All Helper Modules (Copied from source) #
120
+ # ----------------------------------------------- #
121
+ def sinusoids(length, channels, max_timescale=10000, device=None):
122
+ assert channels % 2 == 0
123
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
124
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
125
+ scaled_time = torch.arange(length, device=device)[:, np.newaxis] * inv_timescales[np.newaxis, :]
126
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
127
+
128
+
129
+ def get_sequence_mask(inputs, inputs_length):
130
+ if inputs.dim() == 3:
131
+ bsz, tgt_len, _ = inputs.size()
132
+ else:
133
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
134
+ sequence_mask = torch.arange(0, tgt_len, device=inputs.device)
135
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
136
+ return sequence_mask
137
+
138
+
139
+ class RMSNorm(nn.Module):
140
+ def __init__(self, hidden_size, eps=1e-6):
141
+ super().__init__()
142
+ self.weight = nn.Parameter(torch.ones(hidden_size))
143
+ self.variance_epsilon = eps
144
+
145
+ def forward(self, hidden_states):
146
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
147
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
148
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
149
+ hidden_states = hidden_states.to(self.weight.dtype)
150
+ return self.weight * hidden_states
151
+
152
+
153
+ class VarLenAttention(nn.Module):
154
+ def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0):
155
+ super().__init__()
156
+ self.embed_dim = embed_dim
157
+ self.num_heads = num_heads
158
+ self.head_dim = embed_dim // num_heads
159
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
160
+ self.causal = causal
161
+ self.dropout = nn.Dropout(dropout)
162
+ self.scaling = self.head_dim ** -0.5
163
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
164
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
165
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
166
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
167
+
168
+ def _create_attention_mask(self, seq_len, max_len, device, dtype):
169
+ bsz = seq_len.size(0)
170
+ mask = torch.ones(bsz, 1, max_len, max_len, device=device, dtype=dtype)
171
+ seq_indices = torch.arange(max_len, device=device).unsqueeze(0)
172
+ seq_len_expanded = seq_len.unsqueeze(1)
173
+ valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1)
174
+ mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype)
175
+ if self.causal:
176
+ causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1)
177
+ mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype)
178
+ mask = mask + (1.0 - mask) * torch.finfo(dtype).min
179
+ return mask
180
+
181
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
182
+ bsz, max_len, _ = hidden_states.size()
183
+ query = self.q_proj(hidden_states) * self.scaling
184
+ key = self.k_proj(hidden_states)
185
+ value = self.v_proj(hidden_states)
186
+ query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
187
+ key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
188
+ value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2)
189
+ attn_scores = torch.matmul(query, key.transpose(-1, -2))
190
+ attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype)
191
+ attn_scores = attn_scores + attn_mask
192
+ attn_weights = F.softmax(attn_scores, dim=-1)
193
+ attn_weights = self.dropout(attn_weights)
194
+ attn_output = torch.matmul(attn_weights, value)
195
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim)
196
+ attn_output = self.out_proj(attn_output)
197
+ return attn_output
198
+
199
+
200
+ class OmniWhisperMLP(nn.Module):
201
+ def __init__(self, activation_function="gelu", d_model=1280, ffn_dim=5120):
202
+ super().__init__()
203
+ self.activation_fn = ACT2FN[activation_function]
204
+ self.fc1 = nn.Linear(d_model, ffn_dim)
205
+ self.fc2 = nn.Linear(ffn_dim, d_model)
206
+
207
+ def forward(self, hidden_states):
208
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
209
+ return self.fc2(hidden_states)
210
+
211
+
212
+ class OmniWhisperTransformerLayer(nn.Module):
213
+ def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"):
214
+ super().__init__()
215
+ self.embed_dim = d_model
216
+ if attn_type != "varlen":
217
+ raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.")
218
+ self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal)
219
+ if ln_type == "LayerNorm":
220
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
221
+ elif ln_type == "RMSNorm":
222
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim)
223
+ else:
224
+ raise ValueError(f"Unknown ln_type: {ln_type}")
225
+
226
+ self.mlp = OmniWhisperMLP(activation_function, d_model, ffn_dim)
227
+ if ln_type == "LayerNorm":
228
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
229
+ elif ln_type == "RMSNorm":
230
+ self.final_layer_norm = RMSNorm(self.embed_dim)
231
+ else:
232
+ raise ValueError(f"Unknown ln_type: {ln_type}")
233
+
234
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
235
+ residual = hidden_states
236
+ hidden_states = self.self_attn_layer_norm(hidden_states)
237
+ hidden_states = self.self_attn(hidden_states, seq_len)
238
+ hidden_states = residual + hidden_states
239
+ residual = hidden_states
240
+ hidden_states = self.final_layer_norm(hidden_states)
241
+ hidden_states = self.mlp(hidden_states)
242
+ hidden_states = residual + hidden_states
243
+ if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \
244
+ (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
245
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
246
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
247
+ return hidden_states
248
+
249
+
250
+ class OmniAudioEncoder(nn.Module):
251
+ def __init__(
252
+ self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3,
253
+ d_model=1280, scale_embedding=True, max_audio_seconds=30, encoder_layers=32,
254
+ encoder_attention_heads=20, encoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen"
255
+ ):
256
+ super().__init__()
257
+ self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
258
+ self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
259
+ self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size
260
+ self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1)
261
+ self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1)
262
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
263
+ self.layers = nn.ModuleList([
264
+ OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type)
265
+ for _ in range(encoder_layers)
266
+ ])
267
+ self.layer_norm = nn.LayerNorm(d_model)
268
+
269
+ def forward(self, input_features, input_length, output_hidden_states=False):
270
+ input_features = input_features.to(self.conv1.weight.dtype)
271
+ inputs_embeds = F.gelu(self.conv1(input_features))
272
+ inputs_embeds = F.gelu(self.conv2(inputs_embeds))
273
+ output_length = (input_length // self.stride_size).long()
274
+ hidden_states = inputs_embeds.permute(0, 2, 1)
275
+ bsz, tgt_len, _ = hidden_states.size()
276
+ pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
277
+ hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
278
+ attention_mask = get_sequence_mask(hidden_states, output_length)
279
+ all_hidden = () if output_hidden_states else None
280
+ for layer in self.layers:
281
+ if output_hidden_states:
282
+ all_hidden += (hidden_states,)
283
+ hidden_states = layer(hidden_states, output_length)
284
+ hidden_states = self.layer_norm(hidden_states)
285
+ if output_hidden_states:
286
+ all_hidden += (hidden_states,)
287
+ hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2)
288
+ if not output_hidden_states:
289
+ return hidden_states, output_length
290
+ return hidden_states, output_length, all_hidden
291
+
292
+
293
+ class OmniAudioDecoder(nn.Module):
294
+ def __init__(
295
+ self, num_mel_bins=128, sampling_rate=16000, hop_length=160, stride_size=2, kernel_size=3,
296
+ d_model=1280, scale_embedding=True, max_audio_seconds=30, decoder_layers=32,
297
+ decoder_attention_heads=20, decoder_ffn_dim=5120, activation_function="gelu", attn_type="varlen"
298
+ ):
299
+ super().__init__()
300
+ self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
301
+ self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
302
+ self.num_mel_bins, self.d_model, self.stride_size = num_mel_bins, d_model, stride_size
303
+ self.deconv1 = nn.ConvTranspose1d(d_model, d_model, kernel_size, stride_size, padding=0, output_padding=0)
304
+ self.deconv2 = nn.ConvTranspose1d(d_model, num_mel_bins, kernel_size, stride=1, padding=0)
305
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
306
+ self.layers = nn.ModuleList([
307
+ OmniWhisperTransformerLayer(activation_function, d_model, decoder_attention_heads, decoder_ffn_dim, False, attn_type=attn_type)
308
+ for _ in range(decoder_layers)
309
+ ])
310
+ self.layer_norm = nn.LayerNorm(d_model)
311
+
312
+ def forward(self, hidden_states, input_length):
313
+ hidden_states = hidden_states.transpose(1, 2)
314
+ bsz, tgt_len, _ = hidden_states.size()
315
+ pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
316
+ hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
317
+ attention_mask = get_sequence_mask(hidden_states, input_length)
318
+ for layer in self.layers:
319
+ hidden_states = layer(hidden_states, input_length)
320
+ hidden_states = self.layer_norm(hidden_states)
321
+ hidden_states = torch.where(attention_mask, hidden_states, 0).permute(0, 2, 1)
322
+ output_features = F.gelu(self.deconv1(hidden_states))
323
+ output_features = F.gelu(self.deconv2(output_features))
324
+ expected_length = tgt_len * self.stride_size
325
+ if output_features.size(2) > expected_length:
326
+ output_features = output_features[:, :, :expected_length]
327
+ output_length = input_length * self.stride_size
328
+ return output_features, output_length
329
+
330
+
331
+ class ResidualDownConv(nn.Module):
332
+ def __init__(self, d_model=1280, avg_pooler=4):
333
+ super().__init__()
334
+ self.d_model, self.avg_pooler = d_model, avg_pooler
335
+ self.intermediate_dim = d_model * avg_pooler
336
+ self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
337
+ self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
338
+ self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
339
+ self.act_fn = ACT2FN['silu']
340
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
341
+
342
+ def forward(self, x, input_length):
343
+ output_length = input_length // self.avg_pooler
344
+ x = x.transpose(1, 2)
345
+ batch_size, seq_len, _ = x.shape
346
+ if seq_len % self.avg_pooler != 0:
347
+ pad_size = self.avg_pooler - seq_len % self.avg_pooler
348
+ x = F.pad(x, (0, 0, 0, pad_size), "constant", 0) # Pad sequence dim
349
+ xt = x.permute(0, 2, 1)
350
+ g, u = self.gate_proj(xt).permute(0, 2, 1), self.up_proj(xt).permute(0, 2, 1)
351
+ x = x.reshape(batch_size, -1, self.intermediate_dim)
352
+ c = self.down_proj(self.act_fn(g) * u)
353
+ res = self.layer_norm(c + x).transpose(1, 2)
354
+ return res, output_length
355
+
356
+
357
+ class UpConv(nn.Module):
358
+ def __init__(self, d_model=1280, stride=4):
359
+ super().__init__()
360
+ self.d_model, self.stride = d_model, stride
361
+ self.up_conv = nn.ConvTranspose1d(self.stride * d_model, d_model, stride, stride, bias=False)
362
+
363
+ def forward(self, x, input_length):
364
+ res = self.up_conv(x)
365
+ output_length = input_length * self.stride
366
+ return res, output_length
367
+
368
+
369
+ class Transformer(nn.Module):
370
+ def __init__(
371
+ self, input_dim=1280, d_model=1280, output_dim=1280, max_source_positions=1500,
372
+ encoder_layers=32, encoder_attention_heads=20, encoder_ffn_dim=5120,
373
+ activation_function="gelu", attn_type="varlen"
374
+ ):
375
+ super().__init__()
376
+ self.input_dim, self.d_model, self.output_dim, self.max_source_positions = input_dim, d_model, output_dim, max_source_positions
377
+ self.proj = nn.Linear(input_dim, d_model, bias=True) if input_dim != d_model else None
378
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
379
+ self.layers = nn.ModuleList([
380
+ OmniWhisperTransformerLayer(activation_function, d_model, encoder_attention_heads, encoder_ffn_dim, False, attn_type=attn_type)
381
+ for _ in range(encoder_layers)
382
+ ])
383
+ self.layer_norm = nn.LayerNorm(d_model)
384
+ self.out_proj = nn.Linear(d_model, output_dim, bias=True) if output_dim != d_model else None
385
+
386
+ def forward(self, input_features, input_length, output_hidden_states=False):
387
+ output_length = input_length.long()
388
+ hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) if self.proj else input_features
389
+ hidden_states = hidden_states.permute(0, 2, 1)
390
+ bsz, tgt_len, _ = hidden_states.size()
391
+ pos_embed = self.positional_embedding[:tgt_len] if tgt_len < self.positional_embedding.shape[0] else self.positional_embedding
392
+ hidden_states = (hidden_states.to(torch.float32) + pos_embed).to(hidden_states.dtype)
393
+ attention_mask = get_sequence_mask(hidden_states, output_length)
394
+ all_hidden = () if output_hidden_states else None
395
+ for layer in self.layers:
396
+ if output_hidden_states:
397
+ all_hidden += (hidden_states,)
398
+ hidden_states = layer(hidden_states, output_length)
399
+ hidden_states = self.layer_norm(hidden_states)
400
+ if output_hidden_states:
401
+ all_hidden += (hidden_states,)
402
+ hidden_states = torch.where(attention_mask, hidden_states, 0).transpose(1, 2)
403
+ if self.out_proj:
404
+ hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1)
405
+ if not output_hidden_states:
406
+ return hidden_states, output_length
407
+ return hidden_states, output_length, all_hidden
408
+
409
+
410
+ # Note: the remaining helper classes (ISTFT, ISTFTHead, Vocos, VectorQuantize, ResidualVQ, etc.)
+ # are required dependencies of the model and are defined below in this scope.
415
+ class ISTFT(nn.Module):
416
+ def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
417
+ super().__init__()
418
+ if padding not in ["center", "same"]:
419
+ raise ValueError("Padding must be 'center' or 'same'.")
420
+ self.padding, self.n_fft, self.hop_length, self.win_length = padding, n_fft, hop_length, win_length
421
+ self.register_buffer("window", torch.hann_window(win_length))
422
+
423
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
424
+ if self.padding == "center":
425
+ return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
426
+ elif self.padding == "same":
427
+ pad = (self.win_length - self.hop_length) // 2
428
+ else:
429
+ raise ValueError("Padding must be 'center' or 'same'.")
430
+ B, N, T = spec.shape
431
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") * self.window[None, :, None]
432
+ output_size = (T - 1) * self.hop_length + self.win_length
433
+
434
+ y = F.fold(ifft, (1, output_size), (1, self.win_length), stride=(1, self.hop_length))[:, 0, 0, pad:-pad]
435
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
436
+ window_envelope = torch.nn.functional.fold(
437
+ window_sq,
438
+ output_size=(1, output_size),
439
+ kernel_size=(1, self.win_length),
440
+ stride=(1, self.hop_length),
441
+ ).squeeze()[pad:-pad]
442
+ assert (window_envelope > 1e-11).all()
443
+ return y / window_envelope
444
+
445
+
446
+ class FourierHead(nn.Module):
447
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
448
+ raise NotImplementedError("Subclasses must implement the forward method.")
449
+
450
+
451
+ class ISTFTHead(FourierHead):
452
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
453
+ super().__init__()
454
+ self.out = nn.Linear(dim, n_fft + 2)
455
+ self.istft = ISTFT(n_fft, hop_length, n_fft, padding)
456
+
457
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
458
+ x = self.out(x).transpose(1, 2)
459
+ mag, p = x.chunk(2, dim=1)
460
+ mag = torch.exp(mag).clip(max=1e2)
461
+ s = mag.float() * (torch.cos(p).float() + 1j * torch.sin(p).float())
462
+ return self.istft(s).to(x.dtype)
463
+
464
+
465
+ class AdaLayerNorm(nn.Module):
466
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
467
+ super().__init__()
468
+ self.eps, self.dim = eps, embedding_dim
469
+ self.scale = nn.Embedding(num_embeddings, embedding_dim)
470
+ self.shift = nn.Embedding(num_embeddings, embedding_dim)
471
+ torch.nn.init.ones_(self.scale.weight)
472
+ torch.nn.init.zeros_(self.shift.weight)
473
+
474
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
475
+ scale, shift = self.scale(cond_embedding_id), self.shift(cond_embedding_id)
476
+ x = F.layer_norm(x, (self.dim,), eps=self.eps)
477
+ return x * scale + shift
478
+
479
+
480
+ class ConvNeXtBlock(nn.Module):
481
+ def __init__(self, dim, intermediate_dim, layer_scale_init_value, adanorm_num_embeddings=None):
482
+ super().__init__()
483
+ self.dwconv = nn.Conv1d(dim, dim, 7, 1, 3, groups=dim)
484
+ self.adanorm = adanorm_num_embeddings is not None
485
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim) if self.adanorm else nn.LayerNorm(dim, eps=1e-6)
486
+ self.pwconv1 = nn.Linear(dim, intermediate_dim)
487
+ self.act = nn.GELU()
488
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
489
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None
490
+
491
+ def forward(self, x, cond_embedding_id=None):
492
+ res = x
493
+ x = self.dwconv(x).transpose(1, 2)
494
+ x = self.norm(x, cond_embedding_id) if self.adanorm else self.norm(x)
495
+ x = self.pwconv2(self.act(self.pwconv1(x)))
496
+ if self.gamma is not None:
497
+ x = self.gamma * x
498
+ x = res + x.transpose(1, 2)
499
+ return x
500
+
501
+
502
+ class Backbone(nn.Module):
503
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
504
+ raise NotImplementedError("Subclasses must implement the forward method.")
505
+
506
+
507
+ class VocosBackbone(Backbone):
508
+ def __init__(self, input_channels, dim, intermediate_dim, num_layers, layer_scale_init_value=None, adanorm_num_embeddings=None):
509
+ super().__init__()
510
+ self.input_channels, self.embed = input_channels, nn.Conv1d(input_channels, dim, 7, 1, 3)
511
+ self.adanorm = adanorm_num_embeddings is not None
512
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim) if self.adanorm else nn.LayerNorm(dim, eps=1e-6)
513
+ self.convnext = nn.ModuleList([ConvNeXtBlock(dim, intermediate_dim, layer_scale_init_value or 1/num_layers, adanorm_num_embeddings) for _ in range(num_layers)])
514
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
515
+ self.apply(self._init_weights)
516
+
517
+ def _init_weights(self, m):
518
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
519
+ nn.init.trunc_normal_(m.weight, std=0.02)
520
+ if m.bias is not None:
521
+ nn.init.constant_(m.bias, 0)
522
+
523
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
524
+ x = self.embed(x).transpose(1, 2)
525
+ x = self.norm(x, kwargs.get("bandwidth_id")) if self.adanorm else self.norm(x)
526
+ x = x.transpose(1, 2)
527
+ for block in self.convnext:
528
+ x = block(x, kwargs.get("bandwidth_id"))
529
+ return self.final_layer_norm(x.transpose(1, 2))
530
+
531
+
532
+ class Vocos(nn.Module):
533
+ def __init__(self, input_channels=128, dim=512, intermediate_dim=4096, num_layers=30, n_fft=640, hop_size=160, padding="same", adanorm_num_embeddings=None):
534
+ super().__init__()
535
+ self.backbone = VocosBackbone(input_channels, dim, intermediate_dim, num_layers, adanorm_num_embeddings=adanorm_num_embeddings)
536
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
537
+ self.hop_size = hop_size
538
+
539
+ def forward(self, x, input_length):
540
+ x = self.backbone(x)
541
+ x = self.head(x)
542
+ return x[:, None, :], input_length * self.hop_size
543
+
544
+
545
+ def WNConv1d(*args, **kwargs):
546
+ return weight_norm(nn.Conv1d(*args, **kwargs))
547
+
548
+
549
+ def ema_inplace(moving_avg, new, decay):
550
+ moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay))
551
+
552
+
553
+ def sample_vectors(samples, num):
554
+ num_samples, device = samples.shape[0], samples.device
555
+ indices = torch.randperm(num_samples, device=device)[:num] if num_samples >= num else torch.randint(0, num_samples, (num,), device=device)
556
+ return samples[indices].float()
557
+
558
+
559
+ def kmeans(samples, num_clusters, num_iters=10):
560
+ dim, means = samples.shape[-1], sample_vectors(samples, num_clusters).float()
561
+ for _ in range(num_iters):
562
+ dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True))
563
+ buckets = dists.max(dim=-1).indices
564
+ bins = torch.bincount(buckets, minlength=num_clusters)
565
+ zero_mask = bins == 0
566
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
567
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32).scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) / bins_min_clamped[..., None]
568
+ means = torch.where(zero_mask[..., None], means, new_means)
569
+ dists = -(samples.float().pow(2).sum(1, keepdim=True) - 2 * samples.float() @ means.t() + means.t().float().pow(2).sum(0, keepdim=True))
570
+ return means, torch.bincount(dists.max(dim=-1).indices, minlength=num_clusters).float()
571
+
572
+
573
+ class VectorQuantize(nn.Module):
574
+ def __init__(self, input_dim, codebook_size, codebook_dim, commitment=1.0, decay=0.99, epsilon=1e-5, threshold_ema_dead=2, kmeans_init=True, kmeans_iters=10):
575
+ super().__init__()
576
+ self.input_dim, self.codebook_size, self.codebook_dim = input_dim, codebook_size, codebook_dim
577
+ self.commitment, self.decay, self.epsilon, self.threshold_ema_dead = commitment, decay, epsilon, threshold_ema_dead
578
+ self.kmeans_init, self.kmeans_iters = kmeans_init, kmeans_iters
579
+ self.in_project = WNConv1d(input_dim, codebook_dim, 1) if input_dim != codebook_dim else nn.Identity()
580
+ self.out_project = WNConv1d(codebook_dim, input_dim, 1) if codebook_dim != input_dim else nn.Identity()
581
+ self.register_buffer("codebook", torch.zeros(codebook_size, codebook_dim) if kmeans_init else torch.randn(codebook_size, codebook_dim))
582
+ self.register_buffer("inited", torch.tensor(not kmeans_init, dtype=torch.bool))
583
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
584
+ self.register_buffer("embed_avg", self.codebook.clone())
585
+
586
+ def ema_update(self, encodings, embed_onehot):
587
+ encodings, embed_onehot = encodings.float(), embed_onehot.float()
588
+ cluster_size_new, embed_sum = embed_onehot.sum(0), encodings.t() @ embed_onehot
589
+ if dist.is_initialized():
590
+ dist.all_reduce(cluster_size_new)
591
+ dist.all_reduce(embed_sum)
592
+ ema_inplace(self.cluster_size, cluster_size_new, self.decay)
593
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
594
+ cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) * self.cluster_size.sum()
595
+ self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1))
596
+
597
+ def replace_dead_codes(self, encodings):
598
+ if self.threshold_ema_dead == 0: return
599
+ dead_mask = self.cluster_size < self.threshold_ema_dead
600
+ if dead_mask.any():
601
+ samples = sample_vectors(encodings.float(), self.codebook_size) if not dist.is_initialized() or dist.get_rank() == 0 else torch.zeros_like(self.codebook)
602
+ if dist.is_initialized(): dist.broadcast(samples, src=0)
603
+ self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype)
604
+
605
+ def init_codebook(self, encodings):
606
+ if self.inited.item(): return
607
+ if not dist.is_initialized() or dist.get_rank() == 0:
608
+ embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters)
609
+ else:
610
+ embed, cluster_sizes = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device), torch.zeros(self.codebook_size, device=encodings.device)
611
+ if dist.is_initialized():
612
+ dist.broadcast(embed, src=0)
613
+ dist.broadcast(cluster_sizes, src=0)
614
+ self.codebook.copy_(embed)
615
+ self.embed_avg.copy_(embed.clone())
616
+ self.cluster_size.copy_(cluster_sizes)
617
+ self.inited.fill_(True)
618
+
619
+ def forward(self, z):
620
+ z_e = self.in_project(z.float())
621
+ encodings = rearrange(z_e, "b d t -> (b t) d")
622
+ if self.kmeans_init and not self.inited.item(): self.init_codebook(encodings)
623
+ dist = encodings.pow(2).sum(1, keepdim=True) - 2 * encodings @ self.codebook.float().t() + self.codebook.float().pow(2).sum(1, keepdim=True).t()
624
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=z.size(0))
625
+ z_q = self.decode_code(indices)
626
+ commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment
627
+ if self.training and torch.is_grad_enabled():
628
+ self.ema_update(encodings, F.one_hot(indices.view(-1), self.codebook_size))
629
+ self.replace_dead_codes(encodings)
630
+ z_q = self.out_project(z_e + (z_q - z_e).detach())
631
+ return z_q, commit_loss, torch.tensor(0.0, device=z.device), indices, z_e
632
+
633
+ def decode_code(self, embed_id):
634
+ return F.embedding(embed_id, self.codebook.float()).transpose(1, 2)
635
+
636
+
637
+ class ResidualVQ(nn.Module):
638
+ def __init__(
639
+ self,
640
+ input_dim: int = 1280,
641
+ rvq_dim: int = None,
642
+ output_dim: int = None,
643
+ num_quantizers: int = 32,
644
+ codebook_size: int = 1024,
645
+ codebook_dim: int = 8,
646
+ quantizer_dropout: float = 0.5,
647
+ skip_rvq_ratio: float = 0.0,
648
+ vq_config: VectorQuantizerConfig = None,
649
+ **kwargs
650
+ ):
651
+ super().__init__()
652
+ self.input_dim, self.rvq_dim, self.output_dim = input_dim, rvq_dim, output_dim or input_dim
653
+ self.num_quantizers, self.codebook_size, self.codebook_dim = num_quantizers, codebook_size, codebook_dim
654
+ self.quantizer_dropout, self.skip_rvq_ratio = quantizer_dropout, skip_rvq_ratio
655
+ self.input_proj = WNConv1d(input_dim, rvq_dim, 1) if input_dim != rvq_dim else nn.Identity()
656
+ self.output_proj = WNConv1d(rvq_dim, self.output_dim, 1) if rvq_dim != self.output_dim else nn.Identity()
657
+ if vq_config is None:
658
+ vq_config = VectorQuantizerConfig()
659
+ quantizer_kwargs = asdict(vq_config)
660
+ self.quantizers = nn.ModuleList([VectorQuantize(rvq_dim, codebook_size, codebook_dim, **quantizer_kwargs, **kwargs) for _ in range(num_quantizers)])
661
+
662
+
663
+ def forward(self, z, input_length, n_quantizers: int = None):
664
+ z = self.input_proj(z)
665
+
666
+ with torch.autocast('cuda', enabled=False):
667
+ batch_size, _, max_time = z.shape
668
+ device = z.device
669
+ mask = torch.arange(max_time, device=device).expand(batch_size, max_time) < input_length.unsqueeze(1)
670
+
671
+ quantized_out = torch.zeros_like(z)
672
+ residual = z.clone().float()
673
+
674
+ all_commit_losses = []
675
+ all_indices = []
676
+ all_quantized = []
677
+
678
+ # --- Complexity Reduction Start ---
679
+ # 1. Extracted logic for determining quantizer numbers and skip mask
680
+ n_q_tensor = self._get_n_quantizers_tensor(batch_size, device, n_quantizers)
681
+ skip_mask = self._get_skip_mask(batch_size, device)
682
+ # --- Complexity Reduction End ---
683
+
684
+ max_q_to_run = self.num_quantizers if self.training else (n_quantizers or self.num_quantizers)
685
+
686
+ for i, quantizer in enumerate(self.quantizers[:max_q_to_run]):
687
+ # Create a mask for which batch items are active in this iteration
688
+ active_in_iteration_mask = (i < n_q_tensor)
689
+
690
+ # Skip quantization for items that are not active
691
+ if not active_in_iteration_mask.any():
692
+ # If no items are active, we can add placeholders and continue
693
+ # This branch is less common but handles the case where all items have dropped out
694
+ all_commit_losses.append(torch.tensor(0.0, device=device))
695
+ all_indices.append(torch.zeros(batch_size, max_time, dtype=torch.long, device=device))
696
+ all_quantized.append(torch.zeros_like(z))
697
+ continue
698
+
699
+ masked_residual = residual * mask.unsqueeze(1)
700
+
701
+ # --- Complexity Reduction Start ---
702
+ # 2. Extracted quantization step logic
703
+ z_q_i, commit_loss_i, indices_i = self._quantize_step(quantizer, masked_residual, skip_mask)
704
+ # --- Complexity Reduction End ---
705
+
706
+ # Create a mask for updating tensors (batch items active in this iteration AND within valid length)
707
+ update_mask = (active_in_iteration_mask.view(-1, 1, 1) & mask.unsqueeze(1))
708
+
709
+ quantized_out += z_q_i * update_mask
710
+ residual -= z_q_i * update_mask
711
+
712
+ # Calculate average commitment loss only for active items
713
+ commit_loss_i = commit_loss_i[active_in_iteration_mask].mean() if active_in_iteration_mask.any() else torch.tensor(0.0, device=device)
714
+
715
+ all_commit_losses.append(commit_loss_i)
716
+ all_indices.append(indices_i)
717
+ all_quantized.append(z_q_i)
718
+
719
+ # Pad the outputs if the loop was exited early (e.g., in eval mode with n_quantizers)
720
+ num_loops_done = len(all_commit_losses)
721
+ if num_loops_done < self.num_quantizers:
722
+ remaining = self.num_quantizers - num_loops_done
723
+ all_commit_losses.extend([torch.tensor(0.0, device=device)] * remaining)
724
+ all_indices.extend([torch.zeros(batch_size, max_time, dtype=torch.long, device=device)] * remaining)
725
+ all_quantized.extend([torch.zeros_like(z)] * remaining)
726
+
727
+
728
+ quantized_out = self.output_proj(quantized_out)
729
+ all_indices_tensor = torch.stack(all_indices)
730
+ all_commit_losses_tensor = torch.stack(all_commit_losses)
731
+ all_quantized_tensor = torch.stack(all_quantized)
732
+
733
+ return (
734
+ quantized_out,
735
+ all_indices_tensor,
736
+ all_commit_losses_tensor,
737
+ all_quantized_tensor,
738
+ input_length,
739
+ )
740
+
741
+ def decode_codes(self, codes):
742
+ nq, B, T = codes.shape
743
+ emb = torch.zeros(B, self.rvq_dim, T, device=codes.device, dtype=torch.float32)
744
+ for i, quantizer in enumerate(self.quantizers[:nq]):
745
+ emb += quantizer.decode_code(codes[i])
746
+ return self.output_proj(emb)
747
+
748
+ def _get_n_quantizers_tensor(self, batch_size: int, device: torch.device, n_quantizers_override: Optional[int] = None) -> torch.Tensor:
749
+ """
750
+ Determines the number of quantizers to use for each item in the batch,
751
+ applying dropout during training.
752
+ """
753
+ # If not training or dropout is disabled, use the override or default number of quantizers
754
+ is_training = self.training and torch.is_grad_enabled()
755
+ if not is_training or self.quantizer_dropout == 0:
756
+ num_q = n_quantizers_override or self.num_quantizers
757
+ return torch.full((batch_size,), num_q, dtype=torch.long, device=device)
758
+
759
+ # During training, apply quantizer dropout
760
+ n_q_tensor = torch.full((batch_size,), self.num_quantizers, device=device)
761
+ n_dropout = int(batch_size * self.quantizer_dropout)
762
+ if n_dropout > 0:
763
+ dropout_indices = torch.randperm(batch_size, device=device)[:n_dropout]
764
+ dropout_values = torch.randint(1, self.num_quantizers + 1, (n_dropout,), device=device)
765
+ n_q_tensor[dropout_indices] = dropout_values
766
+
767
+ return n_q_tensor
768
+
769
+ def _get_skip_mask(self, batch_size: int, device: torch.device) -> Optional[torch.Tensor]:
770
+ """Generates a mask for skipping RVQ during training if skip_rvq_ratio > 0."""
771
+ is_training = self.training and torch.is_grad_enabled()
772
+ if not is_training or self.skip_rvq_ratio <= 0:
773
+ return None
774
+
775
+ skip_mask = torch.rand(batch_size, device=device) < self.skip_rvq_ratio
776
+ # Ensure at least one sample is not skipped to avoid errors in modules like DDP
777
+ if skip_mask.all():
778
+ skip_mask[0] = False
779
+ return skip_mask
780
+
781
+ def _quantize_step(self, quantizer, residual, skip_mask):
782
+ """Helper to perform one step of quantization, handling the skip logic."""
783
+ # The main logic is for non-skipped samples
784
+ z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(residual.float())
785
+
786
+ # If skipping is active, overwrite the results for the masked samples
787
+ if skip_mask is not None:
788
+ # For skipped samples, the "quantized" output is the residual itself
789
+ # and the loss is zero.
790
+ skip_mask_expanded = skip_mask.view(-1, 1, 1)
791
+ z_q_i = torch.where(skip_mask_expanded, residual, z_q_i)
792
+ commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i)
793
+
794
+ return z_q_i, commit_loss_i, indices_i
795
+
796
+
797
+
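`_quantize_step` uses a broadcast `torch.where` so that skipped items pass their residual through untouched and contribute zero commitment loss. A minimal sketch of that masking pattern on dummy tensors (shapes only; nothing here comes from the real quantizer):

```python
import torch

B, D, T = 4, 6, 10
residual = torch.randn(B, D, T)
z_q = torch.randn(B, D, T)                 # pretend quantizer output
commit_loss = torch.rand(B)

skip_mask = torch.tensor([True, False, False, True])
z_q = torch.where(skip_mask.view(-1, 1, 1), residual, z_q)                         # skipped items keep the residual
commit_loss = torch.where(skip_mask, torch.zeros_like(commit_loss), commit_loss)   # and contribute no loss
print(torch.equal(z_q[0], residual[0]), commit_loss[0].item())                     # True 0.0
```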
798
+ # ----------------------------------------------- #
799
+ # PreTrainedModel Base Class #
800
+ # ----------------------------------------------- #
801
+ class XYTokenizerPreTrainedModel(PreTrainedAudioTokenizerBase):
802
+ """
803
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
804
+ models.
805
+ """
806
+ config_class = XYTokenizerConfig
807
+ base_model_prefix = "xy_tokenizer"
808
+ main_input_name = "input_values"
809
+ _supports_grad_checkpointing = True
810
+
811
+ def _init_weights(self, module):
812
+ """Initialize the weights."""
813
+ if isinstance(module, (nn.Linear, nn.Conv1d, nn.ConvTranspose1d)):
814
+ module.weight.data.normal_(mean=0.0, std=0.02)
815
+ if module.bias is not None:
816
+ module.bias.data.zero_()
817
+ elif isinstance(module, nn.Embedding):
818
+ module.weight.data.normal_(mean=0.0, std=0.02)
819
+ if module.padding_idx is not None:
820
+ module.weight.data[module.padding_idx].zero_()
821
+
822
+ def _set_gradient_checkpointing(self, module, value=False):
823
+ if isinstance(module, (OmniAudioEncoder, OmniAudioDecoder, Transformer)):
824
+ module.gradient_checkpointing = value
825
+
826
+
827
+ # ----------------------------------------------- #
828
+ # Main Model Class #
829
+ # ----------------------------------------------- #
830
+ class XYTokenizerModel(XYTokenizerPreTrainedModel):
831
+ def __init__(self, config: XYTokenizerConfig):
832
+ super().__init__(config)
833
+ # Reconstruct the nested parameter dictionaries from the flat config
834
+ # This is a bit of boilerplate, but it is necessary to reuse the original module code.
835
+ # A more integrated approach would refactor the sub-modules to accept the flat config directly.
836
+ self.config = config
837
+
838
+ params = config.params
839
+ self.semantic_encoder = OmniAudioEncoder(**params['semantic_encoder_kwargs'])
840
+ self.semantic_encoder_adapter = Transformer(**params['semantic_encoder_adapter_kwargs'])
841
+ self.acoustic_encoder = OmniAudioEncoder(**params['acoustic_encoder_kwargs'])
842
+ self.pre_rvq_adapter = Transformer(**params['pre_rvq_adapter_kwargs'])
843
+ self.downsample = ResidualDownConv(**params['downsample_kwargs'])
844
+ self.quantizer = ResidualVQ(**params['quantizer_kwargs'])
845
+ self.post_rvq_adapter = Transformer(**params['post_rvq_adapter_kwargs'])
846
+ self.upsample = UpConv(**params['upsample_kwargs'])
847
+ self.acoustic_decoder = OmniAudioDecoder(**params['acoustic_decoder_kwargs'])
848
+ self.enhanced_vocos = Vocos(**params['vocos_kwargs'])
849
+ self.feature_extractor = params['feature_extractor_kwargs']
850
+ # Store some config values for easier access
851
+ self.encoder_downsample_rate = config.encoder_downsample_rate
852
+ self.nq = params['quantizer_kwargs']['num_quantizers']
853
+
854
+ # Initialize weights and apply final processing
855
+ self.post_init()
856
+
857
+ def _get_feat_extract_output_lengths(self, input_lengths: Optional[torch.Tensor]):
858
+ """
859
+ Computes the output lengths of the feature extractor.
860
+ """
861
+ def _get_out_len(in_len):
862
+ return (in_len - self.feature_extractor["n_fft"]) // self.feature_extractor["hop_length"] + 1
863
+
864
+ if input_lengths is None:
865
+ return None
866
+
867
+ return torch.tensor([_get_out_len(l) for l in input_lengths], device=self.device)
868
+
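The helper above applies the usual STFT framing formula `(in_len - n_fft) // hop_length + 1`. A worked example, assuming a Whisper-style mel frontend with `n_fft=400` and `hop_length=160` at 16 kHz (these two values are assumptions, not read from this repo's feature extractor config):

```python
n_fft, hop_length = 400, 160      # assumed frontend parameters
in_len = 16000                    # one second of 16 kHz audio, in samples
out_len = (in_len - n_fft) // hop_length + 1
print(out_len)                    # 98 mel frames
```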
869
+ def scale_window_size(self, boundaries, scaling_factor):
870
+ scaling_range = []
871
+ scaling_boundaries = []
872
+ for left_boundary, right_boundary in boundaries:
873
+ scaling_left_boundary = left_boundary // scaling_factor
874
+ scaling_right_boundary = right_boundary // scaling_factor
875
+ scaling_range.append(scaling_right_boundary - scaling_left_boundary)
876
+ scaling_boundaries.append(slice(scaling_left_boundary, scaling_right_boundary))
877
+ return scaling_range, scaling_boundaries
878
+
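`scale_window_size` simply integer-divides sample-level boundaries by a scaling factor to obtain code-frame slices. A quick worked example with `scaling_factor=1280` (used here purely as an illustrative downsample rate):

```python
boundaries = [(0, 160000), (0, 80000)]     # 10 s and 5 s of 16 kHz audio, in samples
scaling_factor = 1280

for left, right in boundaries:
    frames = slice(left // scaling_factor, right // scaling_factor)
    print(frames, frames.stop - frames.start)
# slice(0, 125, None) 125
# slice(0, 62, None) 62
```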
879
+ @torch.inference_mode()
880
+ def encode(
881
+ self,
882
+ features: Union[BatchFeature, ExtractorIterator],
883
+ n_quantizers: Optional[int] = None,
884
+ return_dict: Optional[bool] = True,
885
+ ) -> Union[XYTokenizerEncodeOutput, Tuple]:
886
+ r"""
887
+ Encodes the input audio waveform into discrete codes.
888
+
889
+ Args:
890
+ features (`BatchFeature` or `ExtractorIterator`):
891
+ A single batch of features or an iterator that yields batches of chunks for long audio files.
892
+ The iterator is expected to yield `BatchFeature` dicts which must contain a `chunk_seq_no`
893
+ tensor of shape `(batch_size,)` mapping each item in the chunk to its original sequence, plus the per-chunk `input_lengths`.
894
+ n_quantizers (`int`, *optional*):
895
+ The number of quantizers to use. If not specified, all quantizers are used.
896
+ return_dict (`bool`, *optional*):
897
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
898
+ Returns:
899
+ [`XYTokenizerEncodeOutput`] or `tuple(torch.FloatTensor)`
900
+ """
901
+ assert isinstance(features, (BatchFeature, ExtractorIterator))
902
+ # Handle single batch case
903
+ if isinstance(features, BatchFeature):
904
+ return self._encode(features, n_quantizers, return_dict)
905
+
906
+ # Handle streaming/chunked case
907
+ else:
908
+ # Use a dictionary to group chunks by their original sequence ID
909
+ encodings = defaultdict(lambda: {"zq": [], "codes": [], "length": 0})
910
+ commit_losses = []
911
+ total_frames = 0
912
+
913
+ # 1. Iterate through chunks and store intermediate results
914
+ for chunk_features in features:
915
+ # Always use return_dict=True for easier access to named outputs
916
+ chunk_output = self._encode(chunk_features, n_quantizers, return_dict=True)
917
+ valid_code_lengths, valid_code_ranges = self.scale_window_size(chunk_features["input_lengths"], self.encoder_downsample_rate)
918
+
919
+ # Accumulate weighted commit loss
920
+ chunk_length = chunk_output.codes_lengths.sum().item()
921
+ valid_chunk_length = sum(valid_code_lengths)
922
+ if chunk_output.commit_loss is not None and valid_chunk_length > 0:
923
+ commit_loss = chunk_output.commit_loss / chunk_length * valid_chunk_length
924
+ commit_losses.append((commit_loss.cpu(), valid_chunk_length))
925
+ total_frames += valid_chunk_length
926
+
927
+ # Group results by original sequence ID
928
+ for i, seq_id in enumerate(chunk_features["chunk_seq_no"].tolist()):
929
+ valid_code_range = valid_code_ranges[i]
930
+ if valid_code_range.stop > 0:
931
+ encodings[seq_id]["zq"].append(chunk_output.quantized_representation[i:i+1, :, valid_code_range])
932
+ encodings[seq_id]["codes"].append(chunk_output.audio_codes[:, i:i+1, valid_code_range])
933
+ # Add the valid length of this chunk to the total for this sequence
934
+ encodings[seq_id]["length"] += valid_code_lengths[i]
935
+
936
+ # 2. Concatenate the chunks belonging to each original sequence
+ final_outputs = []
937
+ for seq_id, seq_data in encodings.items():
938
+ final_outputs.append({
939
+ "zq": torch.cat(seq_data["zq"], dim=2),
940
+ "codes": torch.cat(seq_data["codes"], dim=2),
941
+ "length": seq_data["length"]
942
+ })
943
+
944
+ # 3. Pad all sequences to the same length and stack into a batch
945
+ max_len = max(seq["zq"].shape[2] for seq in final_outputs)
946
+
947
+ batch_zq = []
948
+ batch_codes = []
949
+ batch_lengths = []
950
+
951
+ for seq in final_outputs:
952
+ pad_amount = max_len - seq["zq"].shape[2]
953
+ # Pad on the right side of the last dimension (time)
954
+ padded_zq = F.pad(seq["zq"], (0, pad_amount))
955
+ padded_codes = F.pad(seq["codes"], (0, pad_amount))
956
+
957
+ batch_zq.append(padded_zq)
958
+ batch_codes.append(padded_codes)
959
+ batch_lengths.append(seq["length"])
960
+
961
+ # Stack the list of tensors into a single batch tensor
962
+ quantized_representation = torch.cat(batch_zq, dim=0)
963
+ audio_codes = torch.cat(batch_codes, dim=0)
964
+ codes_lengths = torch.tensor(batch_lengths, dtype=torch.long, device=self.device)
965
+
966
+ # 4. Calculate final commit loss
967
+ if total_frames > 0:
968
+ # Weighted average of commit losses
969
+ commit_loss = sum(loss * length for loss, length in commit_losses) / total_frames
970
+ commit_loss = commit_loss.to(self.device)
971
+ else:
972
+ commit_loss = torch.tensor(0.0, device=self.device)
973
+
974
+ if not return_dict:
975
+ return (quantized_representation, audio_codes, codes_lengths, commit_loss)
976
+
977
+ return XYTokenizerEncodeOutput(
978
+ quantized_representation=quantized_representation,
979
+ audio_codes=audio_codes,
980
+ codes_lengths=codes_lengths,
981
+ commit_loss=commit_loss,
982
+ overlap_seconds=features.overlap_seconds,
983
+ )
984
+
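Step 3 of the chunked path right-pads each re-assembled sequence on the time axis before stacking everything into one batch. A small sketch of that pad-and-concatenate pattern on toy tensors (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

seqs = [torch.randn(1, 8, 125), torch.randn(1, 8, 62)]   # per-sequence (1, nq, T_i) toy codes
max_len = max(s.shape[-1] for s in seqs)

batch = torch.cat([F.pad(s, (0, max_len - s.shape[-1])) for s in seqs], dim=0)
lengths = torch.tensor([s.shape[-1] for s in seqs])
print(batch.shape, lengths)   # torch.Size([2, 8, 125]) tensor([125,  62])
```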
985
+ def _encode(
986
+ self,
987
+ features: BatchFeature,
988
+ n_quantizers: Optional[int] = None,
989
+ return_dict: Optional[bool] = True,
990
+ ) -> Union[XYTokenizerEncodeOutput, Tuple]:
991
+ input_mel = features['input_features'].to(self.device, dtype=self.dtype)
992
+ mel_attention_mask = features['attention_mask'].to(self.device)
993
+ mel_output_length = mel_attention_mask.sum(dim=-1).long()
994
+
995
+ # --- Encoder Path ---
996
+ semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)
997
+ semantic_adapter_output, _ = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)
998
+ acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length)
999
+
1000
+ concated_channel = torch.cat([semantic_adapter_output, acoustic_encoder_output], dim=1)
1001
+
1002
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_channel, acoustic_encoder_output_length)
1003
+ downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length)
1004
+
1005
+ n_quantizers = n_quantizers or self.quantizer.num_quantizers
1006
+ zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length, n_quantizers=n_quantizers)
1007
+
1008
+ if not return_dict:
1009
+ return (zq, codes, quantizer_output_length, vq_loss)
1010
+
1011
+ return XYTokenizerEncodeOutput(
1012
+ quantized_representation=zq,
1013
+ audio_codes=codes,
1014
+ codes_lengths=quantizer_output_length,
1015
+ commit_loss=vq_loss.mean()
1016
+ )
1017
+
1018
+ @torch.inference_mode()
1019
+ def decode(
1020
+ self,
1021
+ audio_codes: Union[torch.Tensor, XYTokenizerEncodeOutput],
1022
+ overlap_seconds: int = 10,
1023
+ return_dict: Optional[bool] = True,
1024
+ ) -> Union[XYTokenizerDecodeOutput, Tuple]:
1025
+ r"""
1026
+ Decodes discrete codes back into an audio waveform.
1027
+
1028
+ Args:
1029
+ audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)` or [`XYTokenizerEncodeOutput`]):
1030
+ The discrete codes from the quantizer for each codebook, or the full output of [`~XYTokenizerModel.encode`].
1031
+ overlap_seconds (`int`, *optional*, defaults to 10):
1032
+ The chunk overlap (in seconds) used when the codes were produced by a chunked feature extractor. When an [`XYTokenizerEncodeOutput`] is passed, this value is read from it instead.
1033
+ return_dict (`bool`, *optional*):
1034
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1035
+ Returns:
1036
+ [`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)`
1037
+ """
1038
+ assert not isinstance(audio_codes, tuple), "`codec.decode()` received a tuple; pass `return_dict=True` to `codec.encode()` instead"
1039
+ assert isinstance(audio_codes, (torch.Tensor, XYTokenizerEncodeOutput)), \
1040
+ "`codec.decode()` only accepts `torch.Tensor` or `XYTokenizerEncodeOutput` inputs"
1041
+ if isinstance(audio_codes, XYTokenizerEncodeOutput):
1042
+ # Read overlap_seconds from the encode output before unwrapping the codes tensor
+ if hasattr(audio_codes, "overlap_seconds"):
1043
+ overlap_seconds = audio_codes.overlap_seconds
1044
+ audio_codes = audio_codes.audio_codes
1045
+ if overlap_seconds is None:
1046
+ overlap_seconds = 0
1047
+ chunk_length = self.feature_extractor["chunk_length"]
1048
+ duration_seconds = chunk_length - overlap_seconds
1049
+ chunk_code_length = int(chunk_length * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate) # Maximum code length per chunk
1050
+ duration_code_length = int(duration_seconds * self.feature_extractor["sampling_rate"] // self.config.encoder_downsample_rate) # Valid code length per chunk
1051
+ duration_wav_length = duration_code_length * self.config.decoder_upsample_rate # Valid waveform length per chunk
1052
+
1053
+ # Get maximum code length
1054
+ batch_size = audio_codes.shape[1]
1055
+ codes_list = [audio_codes[:, i, :] for i in range(batch_size)]
1056
+ max_code_length = max(codes.shape[-1] for codes in codes_list)
1057
+ batch_size = len(codes_list)
1058
+ codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=self.device, dtype=torch.long)
1059
+ code_lengths = torch.zeros(batch_size, dtype=torch.long, device=self.device)
1060
+ for i, codes in enumerate(codes_list):
1061
+ codes_tensor[:, i, :codes.shape[-1]] = codes.to(self.device)
1062
+ code_lengths[i] = codes.shape[-1] # (B,)
1063
+
1064
+ # Calculate number of chunks needed
1065
+ max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
1066
+ wav_list = []
1067
+
1068
+ # Process the entire batch in chunks
1069
+ for chunk_idx in range(max_chunks):
1070
+ start = chunk_idx * duration_code_length
1071
+ end = min(start + chunk_code_length, max_code_length)
1072
+ chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
1073
+ chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,)
1074
+
1075
+ # Skip empty chunks
1076
+ if chunk_code_lengths.max() == 0:
1077
+ continue
1078
+
1079
+ # Decode
1080
+ result = self._decode(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)}
1081
+ chunk_wav = result["audio_values"] # (B, 1, T')
1082
+ chunk_wav_lengths = result["output_length"] # (B,)
1083
+
1084
+ # Extract valid portion
1085
+ valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,)
1086
+ valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=self.device)
1087
+ for b in range(batch_size):
1088
+ if valid_wav_lengths[b] > 0:
1089
+ valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length)
1090
+
1091
+ wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
1092
+
1093
+ # Concatenate all chunks
1094
+ if wav_list:
1095
+ wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
1096
+ syn_wav_list = [wav_tensor[i, :, :code_lengths[i] * self.config.decoder_upsample_rate] for i in range(batch_size)] # B * (1, T,)
1097
+ else:
1098
+ syn_wav_list = [torch.zeros(1, 0, device=self.device) for _ in range(batch_size)] # B * (1, 0,)
1099
+
1100
+ if not return_dict:
1101
+ return (syn_wav_list,)
1102
+
1103
+ return XYTokenizerDecodeOutput(
1104
+ audio_values=syn_wav_list
1105
+ )
1106
+
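The chunk sizes in `decode` follow from the feature-extractor chunking and the codec's down/upsample rates. A worked example under assumed values (30 s chunks, 10 s overlap, 16 kHz input, `encoder_downsample_rate=1280`, `decoder_upsample_rate=1920`, i.e. 12.5 Hz codes rendered as 24 kHz audio; these numbers are illustrative, not read from the repo config here):

```python
chunk_length, overlap_seconds, sampling_rate = 30, 10, 16000   # assumed extractor settings
encoder_downsample_rate, decoder_upsample_rate = 1280, 1920    # assumed codec rates

duration_seconds = chunk_length - overlap_seconds                                   # 20 s of fresh audio per chunk
chunk_code_length = chunk_length * sampling_rate // encoder_downsample_rate         # 375 codes decoded per chunk
duration_code_length = duration_seconds * sampling_rate // encoder_downsample_rate  # 250 codes kept per chunk
duration_wav_length = duration_code_length * decoder_upsample_rate                  # 480000 samples = 20 s at 24 kHz
print(chunk_code_length, duration_code_length, duration_wav_length)                 # 375 250 480000
```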
1107
+ def _decode(
1108
+ self,
1109
+ audio_codes: torch.Tensor,
1110
+ codes_lengths: Optional[torch.Tensor] = None,
1111
+ return_dict: Optional[bool] = True,
1112
+ ) -> Union[XYTokenizerDecodeOutput, Tuple]:
1113
+ r"""
1114
+ Decodes discrete codes back into an audio waveform.
1115
+
1116
+ Args:
1117
+ audio_codes (`torch.LongTensor` of shape `(num_codebooks, batch_size, sequence_length)`):
1118
+ The discrete codes from the quantizer for each codebook.
1119
+ codes_lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1120
+ The valid length of each sequence in `audio_codes`. If not provided, it's assumed to be the full length.
1121
+ return_dict (`bool`, *optional*):
1122
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1123
+ Returns:
1124
+ [`XYTokenizerDecodeOutput`] or `tuple(torch.FloatTensor)`
1125
+ """
1126
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1127
+
1128
+ if codes_lengths is None:
1129
+ codes_lengths = torch.full((audio_codes.shape[1],), audio_codes.shape[2], device=self.device)
1130
+
1131
+ # --- Decoder Path ---
1132
+ zq = self.quantizer.decode_codes(audio_codes)
1133
+
1134
+ post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths)
1135
+ upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length)
1136
+ acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length)
1137
+ y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length)
1138
+
1139
+ if not return_dict:
1140
+ return (y, vocos_output_length)
1141
+
1142
+ return XYTokenizerDecodeOutput(
1143
+ audio_values=y,
1144
+ output_length=vocos_output_length
1145
+ )
1146
+
1147
+ def forward(
1148
+ self,
1149
+ input_values: torch.Tensor,
1150
+ attention_mask: Optional[torch.Tensor] = None,
1151
+ n_quantizers: Optional[int] = None,
1152
+ return_dict: Optional[bool] = True,
1153
+ ) -> Union[XYTokenizerModelOutput, Tuple]:
1154
+ r"""
1155
+ The forward method that handles the full encoding and decoding process.
1156
+
1157
+ Args:
1158
+ input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
1159
+ Float values of the input audio waveform.
1160
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1161
+ Mask to avoid performing attention on padding token indices.
1162
+ n_quantizers (`int`, *optional*):
1163
+ The number of quantizers to use for encoding. If not specified, all quantizers are used.
1164
+ return_dict (`bool`, *optional*):
1165
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1166
+
1167
+ Examples:
1168
+
1169
+ ```python
1170
+ >>> from transformers import AutoModel, AutoFeatureExtractor
1171
+ >>> from datasets import load_dataset, Audio
1172
+ >>> import torch
1173
+
1174
+ >>> # This is a placeholder model name, replace with the actual one on the Hub
1175
+ >>> model_id = "your-namespace/xy-tokenizer-model"
1176
+ >>> model = AutoModel.from_pretrained(model_id)
1177
+ >>> # The feature extractor config is part of the model config, so it can be loaded this way
1178
+ >>> feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
1179
+
1180
+ >>> # Load a dummy audio dataset
1181
+ >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
1182
+ >>> audio_sample = ds[0]["audio"]["array"]
1183
+ >>> sampling_rate = ds[0]["audio"]["sampling_rate"]
1184
+
1185
+ >>> # Process audio
1186
+ >>> inputs = feature_extractor(audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
1187
+
1188
+ >>> # Encode to get codes
1189
+ >>> with torch.no_grad():
1190
+ ... encoder_output = model.encode(inputs)
1191
+ ... audio_codes = encoder_output.audio_codes
1192
+
1193
+ >>> # Decode from codes
1194
+ >>> with torch.no_grad():
1195
+ ... decoder_output = model.decode(audio_codes)
1196
+ ... reconstructed_audio = decoder_output.audio_values
1197
+
1198
+ >>> # Full forward pass
1199
+ >>> with torch.no_grad():
1200
+ ... model_output = model(**inputs)
1201
+ ... reconstructed_audio_fwd = model_output.audio_values
1202
+
1203
+ >>> print(reconstructed_audio.shape)
1204
+ torch.Size([1, 1, 147200])
1205
+ >>> print(torch.allclose(reconstructed_audio, reconstructed_audio_fwd))
1206
+ True
1207
+ ```
1208
+
1209
+ Returns:
1210
+ [`XYTokenizerModelOutput`] or `tuple(torch.FloatTensor)`
1211
+ """
1212
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1213
+
1214
+ # encode() expects a BatchFeature keyed like the feature extractor output
+ encoder_outputs = self.encode(
1215
+ features=BatchFeature({"input_features": input_values, "attention_mask": attention_mask}),
1216
+ n_quantizers=n_quantizers,
1217
+ return_dict=True
1218
+ )
1220
+
1221
+ decoder_outputs = self.decode(
1222
+ audio_codes=encoder_outputs,
1223
+ return_dict=True
1224
+ )
1225
+
1226
+ if not return_dict:
1227
+ return (
1228
+ decoder_outputs.audio_values,
1229
+ decoder_outputs.output_length,
1230
+ encoder_outputs.quantized_representation,
1231
+ encoder_outputs.audio_codes,
1232
+ encoder_outputs.codes_lengths,
1233
+ encoder_outputs.commit_loss
1234
+ )
1235
+
1236
+ return XYTokenizerModelOutput(
1237
+ audio_values=decoder_outputs.audio_values,
1238
+ output_length=decoder_outputs.output_length,
1239
+ quantized_representation=encoder_outputs.quantized_representation,
1240
+ audio_codes=encoder_outputs.audio_codes,
1241
+ codes_lengths=encoder_outputs.codes_lengths,
1242
+ commit_loss=encoder_outputs.commit_loss
1243
+ )
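Because `decode` accepts an `XYTokenizerEncodeOutput` directly and reads `overlap_seconds` from it, the two methods compose into a simple round trip. A hedged usage sketch, assuming `model` (an `XYTokenizerModel`) and `inputs` (a `BatchFeature` from this repo's feature extractor) already exist:

```python
import torch

# Sketch only: `model` is an XYTokenizerModel and `inputs` a BatchFeature from its feature extractor.
with torch.no_grad():
    enc = model.encode(inputs)        # XYTokenizerEncodeOutput with codes, lengths, commit loss
    dec = model.decode(enc)           # picks up overlap_seconds from the encode output
    waveforms = dec.audio_values      # list with one (1, T) waveform per input item
```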
preprocessor_config.json CHANGED
@@ -9,5 +9,6 @@
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
11
  "return_attention_mask": true,
12
- "return_tensors": "pt"
 
13
  }
 
9
  "padding_value": 0.0,
10
  "sampling_rate": 16000,
11
  "return_attention_mask": true,
12
+ "return_tensors": "pt",
13
+ "overlap_side": "both"
14
  }