mazesmazes committed on
Commit 10b2bb7 · verified · 1 Parent(s): 87f0e47

Training in progress - step 500

asr_config.py CHANGED
@@ -22,10 +22,15 @@ class ASRConfig(transformers.PretrainedConfig):
         projector_init_std: float = 0.02,
         projector_pool_stride: int = 2,
         projector_hidden_dim: Optional[int] = None,
-        projector_type: str = "moe",  # "moe", "swiglu", or "residual"
+        projector_type: str = "moe",  # "moe", "swiglu", "residual", "shared_moe"
         projector_num_layers: int = 2,  # Number of layers (for residual projector)
         projector_dropout: float = 0.05,  # Dropout rate for projector layers
         projector_input_noise: float = 0.02,  # Input noise for projector
+        # MoE-specific configuration
+        num_experts: int = 4,  # Number of experts in MoE projectors
+        num_experts_per_tok: int = 2,  # Top-k experts per token
+        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
+        use_specaugment: bool = True,  # Apply SpecAugment during training
         label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
         inference_diversity_penalty: float = 0.0,
         inference_warmup_tokens: int = 10,
@@ -45,9 +50,10 @@ class ASRConfig(transformers.PretrainedConfig):
         # Set default generation parameters
         generation_defaults = {
             "num_beams": 1,
-            "max_new_tokens": 64,
-            "min_new_tokens": 1,
+            "max_new_tokens": 256,
+            "min_new_tokens": 0,
             "do_sample": False,
+            "temperature": 0.1,
             "repetition_penalty": 1.0,
             "length_penalty": 1.0,
             "no_repeat_ngram_size": 0,
@@ -73,9 +79,29 @@ class ASRConfig(transformers.PretrainedConfig):
         self.projector_num_layers = projector_num_layers
         self.projector_dropout = projector_dropout
         self.projector_input_noise = projector_input_noise
+        # MoE-specific configuration
+        self.num_experts = num_experts
+        self.num_experts_per_tok = num_experts_per_tok
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.use_specaugment = use_specaugment
         self.label_smoothing = label_smoothing
         self.inference_diversity_penalty = inference_diversity_penalty
         self.inference_warmup_tokens = inference_warmup_tokens
+
+        # Generation parameters (use explicit value if provided, else use default)
+        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
+        self.max_new_tokens = max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
+        self.min_new_tokens = min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
+        self.do_sample = do_sample if do_sample is not None else generation_defaults["do_sample"]
+        self.repetition_penalty = repetition_penalty if repetition_penalty is not None else generation_defaults["repetition_penalty"]
+        self.length_penalty = length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
+        self.no_repeat_ngram_size = no_repeat_ngram_size if no_repeat_ngram_size is not None else generation_defaults["no_repeat_ngram_size"]
+        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
+        self.temperature = temperature if temperature is not None else generation_defaults["temperature"]
+        self.top_k = top_k
+        self.top_p = top_p
+        self.early_stopping = early_stopping
+
         if "audio_config" not in kwargs:
             self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
             # Override dtype to match model_dtype
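The generation parameters above are resolved with an explicit `is not None` check rather than `or`, so falsy values such as 0 or False still override the defaults. A minimal standalone sketch of that fallback pattern (values illustrative, not taken from a real config):

generation_defaults = {"max_new_tokens": 256, "min_new_tokens": 0, "do_sample": False}

def resolve(explicit, key):
    # Explicit values win, even falsy ones; only None falls back to the default.
    return explicit if explicit is not None else generation_defaults[key]

print(resolve(None, "max_new_tokens"))  # 256 -> no explicit value, default used
print(resolve(0, "min_new_tokens"))     # 0   -> explicit falsy value is kept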
asr_modeling.py CHANGED
@@ -16,10 +16,26 @@ from transformers.models.whisper.modeling_whisper import (
     _compute_mask_indices,
 )
 
-from .asr_config import ASRConfig
-from .moe_projector import MoEAudioProjector
-from .residual_projector import ResidualAudioProjector
-from .swiglu_projector import AudioProjector
+try:
+    from .asr_config import ASRConfig
+    from .moe_projector import MoEAudioProjector
+    from .residual_projector import ResidualAudioProjector
+    from .swiglu_projector import AudioProjector
+    from .shared_moe_projector import SharedMoEAudioProjector
+except ImportError:
+    from asr_config import ASRConfig  # type: ignore[no-redef]
+    from moe_projector import MoEAudioProjector  # type: ignore[no-redef]
+    from residual_projector import ResidualAudioProjector  # type: ignore[no-redef]
+    from swiglu_projector import AudioProjector  # type: ignore[no-redef]
+    from shared_moe_projector import SharedMoEAudioProjector  # type: ignore[no-redef]
+
+# Map projector type names to classes
+PROJECTOR_CLASSES = {
+    "swiglu": AudioProjector,
+    "residual": ResidualAudioProjector,
+    "moe": MoEAudioProjector,
+    "shared_moe": SharedMoEAudioProjector,
+}
 
 
 class ASRModel(PreTrainedModel):
@@ -173,12 +189,13 @@ class ASRModel(PreTrainedModel):
 
         # Select projector type based on config
         projector_type = getattr(config, "projector_type", "moe")
-        if projector_type == "swiglu":
-            projector = AudioProjector(config)
-        elif projector_type == "residual":
-            projector = ResidualAudioProjector(config)
-        else:  # default to "moe"
-            projector = MoEAudioProjector(config)
+        projector_class = PROJECTOR_CLASSES.get(projector_type)
+        if projector_class is None:
+            raise ValueError(
+                f"Unknown projector_type: {projector_type}. "
+                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
+            )
+        projector = projector_class(config)
 
         # Move projector to same device as language model (important when using quantization)
         device = next(self.language_model.parameters()).device
@@ -234,7 +251,10 @@ class ASRModel(PreTrainedModel):
 
     def get_processor(self):
         """Get the processor for this model."""
-        from .asr_processing import ASRProcessor
+        try:
+            from .asr_processing import ASRProcessor
+        except ImportError:
+            from asr_processing import ASRProcessor  # type: ignore[no-redef]
 
         return ASRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
 
@@ -247,15 +267,7 @@ class ASRModel(PreTrainedModel):
         input_features: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        """
-        Apply SpecAugment masking to input features.
 
-        Uses Whisper's default parameters:
-        - mask_time_prob: 0.05 (5% of time steps)
-        - mask_time_length: 10 frames
-        - mask_feature_prob: 0.0 (disabled by default)
-        - mask_feature_length: 10 features
-        """
         if not getattr(self.config, "use_specaugment", False):
             return input_features
 
@@ -459,6 +471,12 @@ class ASRModel(PreTrainedModel):
                 label_smoothing=getattr(self.config, "label_smoothing", 0.0),
             )
 
+        # Add auxiliary loss from MoE projectors if available
+        if hasattr(self.projector, "get_aux_loss"):
+            aux_loss = self.projector.get_aux_loss()
+            if aux_loss is not None and aux_loss.numel() > 0:
+                loss = loss + aux_loss.to(loss.device)
+
         return CausalLMOutputWithPast(
             loss=loss,
             logits=outputs.logits,
@@ -521,18 +539,38 @@ class ASRModel(PreTrainedModel):
             prompt_ids, audio_embeds, audio_mask=audio_mask
         )
 
-        # Set generation defaults
-        generate_kwargs.setdefault("max_new_tokens", getattr(self.config, "max_new_tokens", 128))
-        generate_kwargs.setdefault("use_cache", True)
+        # Set generation defaults from config
+        generate_kwargs.setdefault("max_new_tokens", self.config.max_new_tokens)
+        generate_kwargs.setdefault("num_beams", self.config.num_beams)
+        generate_kwargs.setdefault("do_sample", self.config.do_sample)
+        generate_kwargs.setdefault("use_cache", self.config.use_cache)
+        generate_kwargs.setdefault("length_penalty", self.config.length_penalty)
+        generate_kwargs.setdefault("repetition_penalty", self.config.repetition_penalty)
+        generate_kwargs.setdefault("no_repeat_ngram_size", self.config.no_repeat_ngram_size)
+        generate_kwargs.setdefault("temperature", self.config.temperature)
+        if self.config.top_k is not None:
+            generate_kwargs.setdefault("top_k", self.config.top_k)
+        if self.config.top_p is not None:
+            generate_kwargs.setdefault("top_p", self.config.top_p)
         generate_kwargs.setdefault(
             "eos_token_id", self.tokenizer.convert_tokens_to_ids("<|im_end|>")
        )
         generate_kwargs.setdefault("pad_token_id", self.tokenizer.pad_token_id)
 
+        # Create dummy input_ids matching inputs_embeds length for repetition penalty tracking
+        # Use pad_token_id as placeholder since the actual tokens don't matter for penalty calc
+        dummy_input_ids = torch.full(
+            (inputs_embeds.shape[0], inputs_embeds.shape[1]),
+            self.tokenizer.pad_token_id,
+            dtype=torch.long,
+            device=device,
+        )
+
         # Generate (type ignore needed as generate() has complex return type)
         # Note: When using inputs_embeds, generate() returns only new tokens
         # (no placeholder positions for input embeddings), so no stripping needed
         output = self.language_model.generate(  # type: ignore[operator]
+            input_ids=dummy_input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             **generate_kwargs,
@@ -576,6 +614,17 @@ class ASRModel(PreTrainedModel):
         src_dir = PathlibPath(__file__).parent
         for asr_file in src_dir.glob("asr_*.py"):
             shutil.copy(asr_file, save_dir / asr_file.name)
+        # Copy projector files
+        projector_files = [
+            "moe_projector.py",
+            "residual_projector.py",
+            "swiglu_projector.py",
+            "shared_moe_projector.py",
+        ]
+        for projector_file in projector_files:
+            src_path = src_dir / projector_file
+            if src_path.exists():
+                shutil.copy(src_path, save_dir / projector_file)
 
 
         # Register with transformers Auto classes
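The auxiliary-loss hook added to forward() is duck-typed: any projector that exposes get_aux_loss() contributes to the training loss, and projectors without the method are left alone. A small self-contained sketch of that pattern (the toy projector and loss values here are made up for illustration):

import torch
import torch.nn as nn

class ToyProjector(nn.Module):
    # Stand-in for an MoE-style projector exposing the optional hook.
    def get_aux_loss(self):
        return torch.tensor(0.01)

projector = ToyProjector()
loss = torch.tensor(2.50)

# Same guard as in ASRModel.forward(): only add the term when the hook exists
# and returns a non-empty tensor.
if hasattr(projector, "get_aux_loss"):
    aux_loss = projector.get_aux_loss()
    if aux_loss is not None and aux_loss.numel() > 0:
        loss = loss + aux_loss.to(loss.device)

print(loss)  # tensor(2.5100)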
asr_pipeline.py CHANGED
@@ -4,9 +4,11 @@ from typing import Any
 
 import torch
 import transformers
-from truecase import get_true_case
 
-from .asr_modeling import ASRModel
+try:
+    from .asr_modeling import ASRModel
+except ImportError:
+    from asr_modeling import ASRModel  # type: ignore[no-redef]
 
 
 class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
@@ -27,11 +29,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
         )
 
-        # Initialize text normalizer (WhisperTokenizer has the normalize method we need)
-        from transformers import WhisperTokenizer
-
-        self.text_normalizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
-
     def __call__(self, inputs, **kwargs):
         generate_kwargs = {}
         generate_keys = [
@@ -89,8 +86,6 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)
 
         text = self.tokenizer.decode(all_tokens, skip_special_tokens=True).strip()
-        text = self.text_normalizer.normalize(text)
-        text = get_true_case(text)
 
         return {"text": text}
 
@@ -105,7 +100,13 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
         if "bytes" in inputs:
             inputs = self._decode_audio_bytes(inputs["bytes"])
         elif "array" in inputs:
-            inputs = {"raw": inputs["array"], "sampling_rate": inputs["sampling_rate"]}
+            inputs = {
+                "raw": inputs["array"],
+                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
+            }
+        elif "path" in inputs and "array" not in inputs:
+            # Lazy-loaded audio - load from path
+            inputs = self._decode_audio_bytes(Path(inputs["path"]).read_bytes())
         elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
             inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
         elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
@@ -116,6 +117,23 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
                 "sampling_rate": self.model.config.audio_sample_rate,
             }
 
+        # Resample to target sample rate if needed (workaround for transformers bug)
+        # See: https://github.com/huggingface/transformers/pull/41298
+        if isinstance(inputs, dict) and "sampling_rate" in inputs:
+            in_sr = inputs["sampling_rate"]
+            target_sr = self.feature_extractor.sampling_rate
+            if in_sr != target_sr:
+                import librosa
+                import numpy as np
+
+                audio = inputs["raw"]
+                if hasattr(audio, "numpy"):
+                    audio = audio.numpy()
+                resampled = librosa.resample(
+                    np.asarray(audio, dtype=np.float32), orig_sr=in_sr, target_sr=target_sr
+                )
+                inputs = {"raw": resampled, "sampling_rate": target_sr}
+
         return super().preprocess(inputs, **preprocess_params)
 
     def _decode_audio_bytes(self, wav_bytes: bytes) -> dict[str, Any]:
@@ -225,7 +243,5 @@ class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
             tokens = tokens[0]
 
         text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
-        text = self.text_normalizer.normalize(text)
-        text = get_true_case(text)
 
         return {"text": text}
asr_processing.py CHANGED
@@ -4,7 +4,10 @@ from pathlib import Path
 import transformers
 from transformers import AutoTokenizer, ProcessorMixin
 
-from .asr_config import ASRConfig
+try:
+    from .asr_config import ASRConfig
+except ImportError:
+    from asr_config import ASRConfig  # type: ignore[no-redef]
 
 
 class ASRProcessor(ProcessorMixin):
moe_projector.py CHANGED
@@ -11,14 +11,19 @@ class SimpleAdapter(nn.Module):
     projecting the hidden dimension from 3072 to 4096 and back to 3072."
     """
 
-    def __init__(self, in_features, hidden_features, out_features):
+    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
         super().__init__()
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(dropout)
         self.fc2 = nn.Linear(hidden_features, out_features)
 
     def forward(self, x):
-        return self.fc2(self.relu(self.fc1(x)))
+        x = self.fc1(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        x = self.fc2(x)
+        return x
 
 
 class MoEAudioProjector(nn.Module):
@@ -47,6 +52,9 @@ class MoEAudioProjector(nn.Module):
         # Adapter hidden dim: paper uses 4096
         adapter_hidden = getattr(config, "projector_hidden_dim", None) or 4096
 
+        # Dropout rate for experts (not applied to router)
+        self.dropout_rate = getattr(config, "projector_dropout", 0.1)
+
         # --- Convolutional Subsampling (Section III-B) ---
         # "two convolutional layers, each with a kernel size of 3 and a stride of 2"
         # Maps encoder_dim (1280) -> llm_dim (3072), total stride=4
@@ -70,7 +78,7 @@ class MoEAudioProjector(nn.Module):
         # "projecting the hidden dimension from 3072 to 4096 and back to 3072"
         self.experts = nn.ModuleList(
             [
-                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
+                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim, dropout=self.dropout_rate)
                 for _ in range(self.num_experts)
             ]
         )
@@ -149,3 +157,7 @@ class MoEAudioProjector(nn.Module):
             final_out.add_(expert_out * expert_weight)
 
         return self.ln_post(final_out)
+
+    def get_aux_loss(self) -> torch.Tensor:
+        """Return auxiliary loss (none for dense MoE - all experts always used)."""
+        return torch.tensor(0.0)
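The SimpleAdapter change inserts dropout between the two linear layers, so it is active only in training mode and becomes a no-op under eval(). A standalone sketch with illustrative dimensions (not the real 3072/4096 sizes):

import torch
import torch.nn as nn

class SimpleAdapter(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout=0.0):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)  # active in train(), identity in eval()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return self.fc2(self.dropout(self.relu(self.fc1(x))))

adapter = SimpleAdapter(8, 16, 8, dropout=0.05).eval()
out = adapter(torch.randn(2, 4, 8))
print(out.shape)  # torch.Size([2, 4, 8]) -- shape is preserved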
shared_moe_projector.py ADDED
@@ -0,0 +1,334 @@
+"""Shared MoE Audio Projector.
+
+A simplified MoE projector combining the best ideas:
+- Shared expert: Always-on baseline processing (from GLM4)
+- Zero-initialized router: Learns specialization naturally (from Qwen3)
+- Simple top-k softmax: No grouping complexity (from Mixtral)
+- Renormalized weights: Top-k weights sum to 1
+
+Architecture:
+    Output = SharedExpert(x) + TopKRoutedExperts(x)
+
+The shared expert ensures every audio token gets consistent baseline
+processing, while routed experts can specialize for different patterns
+(e.g., vowels vs consonants, silence vs speech).
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # noqa: N812
+
+
+class SharedExpert(nn.Module):
+    """Shared expert MLP that processes all tokens."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__()
+        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+class SwiGLUExpert(nn.Module):
+    """Single SwiGLU expert MLP."""
+
+    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+        super().__init__()
+        self.gate_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(input_dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, output_dim, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+class RoutedExperts(nn.Module):
+    """
+    Sparse routed experts using token dispatch.
+
+    For each expert, gathers assigned tokens, processes them, then scatters back.
+    Memory-efficient: O(num_tokens * hidden_dim) instead of
+    O(num_tokens * num_experts * hidden_dim * input_dim).
+    """
+
+    def __init__(
+        self, num_experts: int, top_k: int, input_dim: int, hidden_dim: int, output_dim: int
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.output_dim = output_dim
+
+        # ModuleList of expert MLPs
+        self.experts = nn.ModuleList([
+            SwiGLUExpert(input_dim, hidden_dim, output_dim)
+            for _ in range(num_experts)
+        ])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        top_k_indices: torch.Tensor,
+        top_k_weights: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Token dispatch approach - memory efficient.
+
+        Args:
+            hidden_states: [num_tokens, input_dim]
+            top_k_indices: [num_tokens, top_k]
+            top_k_weights: [num_tokens, top_k]
+
+        Returns:
+            output: [num_tokens, output_dim]
+        """
+        num_tokens = hidden_states.shape[0]
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # Output accumulator
+        output = torch.zeros(num_tokens, self.output_dim, device=device, dtype=dtype)
+
+        # Process each expert
+        for expert_idx, expert in enumerate(self.experts):
+            # Find which (token, slot) pairs use this expert
+            # top_k_indices: [N, K], we want all positions where value == expert_idx
+            expert_mask = top_k_indices == expert_idx  # [N, K]
+
+            if not expert_mask.any():
+                continue
+
+            # Get token indices and slot indices where this expert is selected
+            token_indices, slot_indices = torch.where(expert_mask)
+
+            # Gather the tokens for this expert
+            expert_input = hidden_states[token_indices]  # [num_selected, input_dim]
+
+            # Process through expert
+            expert_output = expert(expert_input)  # [num_selected, output_dim]
+
+            # Get weights for these tokens at these slots
+            weights = top_k_weights[token_indices, slot_indices]  # [num_selected]
+
+            # Weighted output
+            weighted_output = expert_output * weights.unsqueeze(-1)
+
+            # Scatter-add back to output
+            output.index_add_(0, token_indices, weighted_output)
+
+        return output
+
+
+class SharedMoEBlock(nn.Module):
+    """MoE block with shared expert + sparse routed experts."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_experts: int = 4,
+        top_k: int = 2,
+    ):
+        super().__init__()
+        self.num_experts = num_experts
+        self.top_k = top_k
+
+        # Router: zero-initialized for natural learning
+        self.router = nn.Linear(input_dim, num_experts, bias=False)
+        nn.init.zeros_(self.router.weight)
+
+        # Shared expert (always active)
+        self.shared_expert = SharedExpert(input_dim, hidden_dim, output_dim)
+
+        # Routed experts (sparse)
+        self.routed_experts = RoutedExperts(
+            num_experts, self.top_k, input_dim, hidden_dim, output_dim
+        )
+
+
+        # For auxiliary loss
+        self.last_router_logits = None
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, seq_len, dim = hidden_states.shape
+
+        # Shared expert output (all tokens)
+        shared_out = self.shared_expert(hidden_states)
+
+        # Routing
+        flat_hidden = hidden_states.view(-1, dim)
+        router_logits = self.router(flat_hidden)
+        self.last_router_logits = router_logits
+
+        # Softmax -> top-k -> renormalize
+        router_probs = F.softmax(router_logits.float(), dim=-1)
+        top_k_weights, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
+        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
+        top_k_weights = top_k_weights.to(hidden_states.dtype)
+
+        # Routed expert output
+        routed_out = self.routed_experts(flat_hidden, top_k_indices, top_k_weights)
+        routed_out = routed_out.view(batch_size, seq_len, -1)
+
+        # Combine: shared expert baseline + routed experts (grow in via zero-init down_proj)
+        return shared_out + routed_out
+
+
+def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
+    """Auxiliary loss to encourage balanced expert usage."""
+    if router_logits is None:
+        return torch.tensor(0.0)
+
+    probs = F.softmax(router_logits.float(), dim=-1)
+    _, selected = torch.topk(probs, top_k, dim=-1)
+
+    # Fraction of tokens per expert
+    expert_mask = F.one_hot(selected, num_experts).float()
+    tokens_per_expert = expert_mask.mean(dim=(0, 1))
+
+    # Average probability per expert
+    prob_per_expert = probs.mean(dim=0)
+
+    # Balance loss
+    return (tokens_per_expert * prob_per_expert).sum() * num_experts
+
+
+def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
+    """Z-loss to prevent router logits from growing too large.
+
+    From DeepSeek/Switch Transformer: penalizes large logits to keep
+    softmax in its "soft" regime where gradients flow properly.
+    """
+    if router_logits is None:
+        return torch.tensor(0.0)
+
+    # logsumexp ≈ max(logits), squaring penalizes large values
+    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
+
+
+class SharedMoEAudioProjector(nn.Module):
+    """Shared MoE Audio Projector.
+
+    Combines a shared expert (always-on) with sparse routed experts.
+    Uses zero-initialized router for natural specialization learning.
+
+    Config options:
+    - num_experts: Number of routed experts (default: 4)
+    - num_experts_per_tok: Top-k routing (default: 2)
+    - router_aux_loss_coef: Load balancing loss weight (default: 0.01)
+    - router_z_loss_coef: Z-loss weight to prevent large logits (default: 0.001)
+    """
+
+    def __init__(self, config):
+        super().__init__()
+
+        # Temporal downsampling
+        self.k = getattr(config, "projector_pool_stride", 4)
+
+        # Dimensions
+        self.encoder_dim = config.encoder_dim
+        in_dim = self.encoder_dim * self.k
+        out_dim = config.llm_dim
+        # No expansion - keep hidden dim same as input dim
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
+
+        # MoE config
+        self.num_experts = getattr(config, "num_experts", 4)
+        self.top_k = getattr(config, "num_experts_per_tok", 2)
+        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.01)
+        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
+
+        # Layers
+        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
+
+        # Init
+        self._init_weights()
+
+    def _init_weights(self):
+        with torch.no_grad():
+            # Xavier init: std = 1/sqrt(fan_in)
+            in_dim = self.encoder_dim * self.k
+            std = 1.0 / (in_dim ** 0.5)
+
+            # Use a smaller std for the final projection in the shared expert's residual path
+            down_proj_std = std / 2.0
+
+            # Shared expert
+            nn.init.normal_(self.moe.shared_expert.gate_proj.weight, std=std)
+            nn.init.normal_(self.moe.shared_expert.up_proj.weight, std=std)
+            nn.init.normal_(self.moe.shared_expert.down_proj.weight, std=down_proj_std)
+
+            # Routed experts - zero init down_proj so they "grow in" from zero
+            for expert in self.moe.routed_experts.experts:
+                nn.init.normal_(expert.gate_proj.weight, std=std)
+                nn.init.normal_(expert.up_proj.weight, std=std)
+                nn.init.zeros_(expert.down_proj.weight)
+
+            # Router stays zero-initialized
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
+        batch_size, seq_len, dim = x.size()
+
+        # Dtype
+        target_dtype = self.moe.shared_expert.gate_proj.weight.dtype
+        if x.dtype != target_dtype:
+            x = x.to(target_dtype)
+
+        # Pad for pooling
+        if seq_len % self.k:
+            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
+            if attention_mask is not None:
+                attention_mask = F.pad(attention_mask, (0, self.k - seq_len % self.k), value=0)
+
+        # Store pooled attention mask for aux loss
+        if attention_mask is not None:
+            # Max-pool the attention mask
+            pooled_mask = F.max_pool1d(attention_mask.float().unsqueeze(1), self.k, self.k)
+            self.last_attention_mask = pooled_mask.squeeze(1).bool()
+        else:
+            self.last_attention_mask = None
+
+        # Temporal pooling
+        x = x.view(batch_size, -1, dim * self.k)
+
+        # Forward
+        x = self.moe(x)
+
+        return x
+
+    def get_aux_loss(self) -> torch.Tensor:
+        """Get auxiliary losses (call after forward).
+
+        Combines:
+        - Load balancing loss: encourages balanced expert usage
+        - Z-loss: prevents router logits from growing too large
+        """
+        router_logits = self.moe.last_router_logits
+        if router_logits is None:
+            return torch.tensor(0.0, device=self.moe.router.weight.device)
+
+        # Retrieve the attention mask stored during the forward pass
+        attention_mask = getattr(self, "last_attention_mask", None)
+
+        # If a mask exists, filter the logits to only include un-padded tokens
+        if attention_mask is not None:
+            flat_mask = attention_mask.view(-1)
+            # Ensure the mask is not all False, which would create an empty tensor
+            if flat_mask.any():
+                active_logits = router_logits[flat_mask]
+            else:
+                # If the mask is all False, there are no tokens to compute loss on
+                return torch.tensor(0.0, device=router_logits.device)
+        else:
+            active_logits = router_logits
+
+        balance_loss = load_balancing_loss(active_logits, self.num_experts, self.top_k)
+        z = z_loss(active_logits)
+
+        return self.aux_loss_coef * balance_loss + self.z_loss_coef * z
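For reference, a tiny numeric sketch of the routing used in SharedMoEBlock: softmax over the expert logits, keep the top-k, renormalize so the kept weights sum to 1 (the logit values below are made up). With the zero-initialized router all logits start at 0, so routing begins uniform, and the zero-initialized expert down_proj means the routed branch contributes exactly 0 at the first step.

import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])  # 1 token, 4 experts
probs = F.softmax(router_logits, dim=-1)
top_k_weights, top_k_indices = torch.topk(probs, k=2, dim=-1)
top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

print(top_k_indices)  # tensor([[0, 1]]) -> experts 0 and 1 handle this token
print(top_k_weights)  # roughly tensor([[0.8176, 0.1824]]); sums to 1 per token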
swiglu_projector.py CHANGED
@@ -25,34 +25,34 @@ class SwiGLU(nn.Module):
 class AudioProjector(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.k = getattr(config, "projector_pool_stride", 4)  # Downsampling rate
+        self.k = getattr(config, "projector_pool_stride", 4)
         in_dim = config.encoder_dim * self.k
         out_dim = config.llm_dim
         hidden_dim = config.projector_hidden_dim
         if hidden_dim is None:
-            hidden_dim = config.encoder_dim * 4
+            hidden_dim = config.encoder_dim * 2
 
         dropout_rate = getattr(config, "projector_dropout", 0.0)
 
-        from transformers.models.llama.modeling_llama import LlamaRMSNorm
-
-        self.ln_pre = LlamaRMSNorm(in_dim, eps=1e-6)
-        self.proj = SwiGLU(in_dim, hidden_dim, out_dim, dropout=dropout_rate)
-        self.ln_post = LlamaRMSNorm(out_dim, eps=1e-6)
+        self.proj1 = SwiGLU(in_dim, hidden_dim, hidden_dim, dropout=dropout_rate)
+        self.proj2 = SwiGLU(hidden_dim, hidden_dim, out_dim, dropout=dropout_rate)
         self.output_dropout = nn.Dropout(dropout_rate)
 
         with torch.no_grad():
             std = getattr(config, "projector_init_std", 0.02)
-            self.ln_pre.weight.data.fill_(1.0)
-            self.ln_post.weight.data.fill_(1.0)
-            nn.init.normal_(self.proj.w1.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w2.weight, mean=0.0, std=std)
-            nn.init.normal_(self.proj.w3.weight, mean=0.0, std=std)
+            # Initialize first layer
+            nn.init.normal_(self.proj1.w1.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj1.w2.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj1.w3.weight, mean=0.0, std=std)
+            # Initialize second layer
+            nn.init.normal_(self.proj2.w1.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj2.w2.weight, mean=0.0, std=std)
+            nn.init.normal_(self.proj2.w3.weight, mean=0.0, std=std)
 
     def forward(self, x):
         batch_size, seq_len, dim = x.size()
 
-        target_dtype = self.proj.w1.weight.dtype
+        target_dtype = self.proj1.w1.weight.dtype
         if x.dtype != target_dtype:
             x = x.to(target_dtype)
 
@@ -62,8 +62,7 @@ class AudioProjector(nn.Module):
             x = F.pad(x, (0, 0, 0, pad_len))
 
         x = x.contiguous().view(batch_size, -1, dim * self.k)
-        x = self.ln_pre(x)
-        x = self.proj(x)
-        x = self.ln_post(x)
+        x = self.proj1(x)
+        x = self.proj2(x)
 
         return self.output_dropout(x)
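The reworked AudioProjector keeps the same pooling step: frames are padded to a multiple of the stride k and then grouped, so the sequence shrinks by k while the feature dimension grows to encoder_dim * k before the two SwiGLU layers. A shape-only sketch, using the 1280-dim encoder features mentioned in the MoE projector comments and a pool stride of k=4 (batch size and length are illustrative):

import torch
import torch.nn.functional as F

batch, seq_len, encoder_dim, k = 2, 98, 1280, 4
x = torch.randn(batch, seq_len, encoder_dim)

pad_len = (-seq_len) % k                       # pad so the length is divisible by k
x = F.pad(x, (0, 0, 0, pad_len))
x = x.contiguous().view(batch, -1, encoder_dim * k)

print(x.shape)  # torch.Size([2, 25, 5120]) -> 4x shorter, 4x wider per frame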
tokenizer.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d4aeaf198f783cbf58d8cd59812baac429ffe49147bf9648f6618de20b8d4a4c
-size 17209003
+oid sha256:64999f2f5e05d34613701df1999669c5dce7e3891e1628a002518ee68a8626d1
+size 17209101