from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.whisper.modeling_whisper import (
    _compute_mask_indices,
)

try:
    from .asr_config import ASRConfig
    from .moe_projector import MoEAudioProjector
    from .residual_projector import ResidualAudioProjector
    from .swiglu_projector import AudioProjector
    from .shared_moe_projector import SharedMoEAudioProjector
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]
    from moe_projector import MoEAudioProjector  # type: ignore[no-redef]
    from residual_projector import ResidualAudioProjector  # type: ignore[no-redef]
    from swiglu_projector import AudioProjector  # type: ignore[no-redef]
    from shared_moe_projector import SharedMoEAudioProjector  # type: ignore[no-redef]

# Map projector type names to classes
PROJECTOR_CLASSES = {
    "swiglu": AudioProjector,
    "residual": ResidualAudioProjector,
    "moe": MoEAudioProjector,
    "shared_moe": SharedMoEAudioProjector,
}


class ASRModel(PreTrainedModel):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = False  # Frozen encoder/LLM don't benefit; projector is small

    _is_loading_from_pretrained: bool = False
    _pretrained_model_path: Optional[str] = None

    TASK_PROMPTS = {
        "transcribe": "Transcribe: