| from collections.abc import Callable |
| from typing import Optional |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
| from transformers import PreTrainedModel, Qwen2ForCausalLM |
| from transformers.activations import ACT2FN |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import BaseModelOutput |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
| from transformers.utils import auto_docstring |
| from .configuration_uas_audio import UASAudioConfig, UASAudioEncoderConfig, UASAudioEncoderOnlyConfig |
|
|
|
|
| def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """ |
| This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, |
| num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) |
| """ |
| batch, num_key_value_heads, slen, head_dim = hidden_states.shape |
| if n_rep == 1: |
| return hidden_states |
| hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) |
| return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) |
|
|
|
|
| def _get_feat_extract_output_lengths(input_lengths): |
| """ |
| Computes the output length of the convolutional layers and the output length of the audio encoder |
| """ |
|
|
| input_lengths_leave = input_lengths % 100 |
| feat_lengths = (input_lengths_leave - 1) // 2 + 1 |
| output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 |
| return output_lengths |
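# Worked example for _get_feat_extract_output_lengths (illustrative): an input of 250 mel frames
# has two full 100-frame chunks (2 * 13 = 26 output frames) plus 50 leftover frames that reduce
# as 50 -> 25 -> 13 -> 7, giving 26 + 7 = 33 output frames in total.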
|
|
|
|
| def eager_attention_forward( |
| module: nn.Module, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attention_mask: Optional[torch.Tensor], |
| scaling: float, |
| dropout: float = 0.0, |
| **kwargs, |
| ): |
| key_states = repeat_kv(key, module.num_key_value_groups) |
| value_states = repeat_kv(value, module.num_key_value_groups) |
|
|
| attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling |
| if attention_mask is not None: |
| causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| attn_weights = attn_weights + causal_mask |
|
|
| attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
| attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
| attn_output = torch.matmul(attn_weights, value_states) |
| attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
| return attn_output, attn_weights |
|
|
|
|
| class SinusoidsPositionEmbedding(nn.Module): |
| def __init__(self, length, channels, max_timescale=10000): |
| super().__init__() |
| if channels % 2 != 0: |
| raise ValueError("SinusoidsPositionEmbedding needs even channels input") |
| log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) |
| inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) |
| scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] |
| self.register_buffer( |
| "positional_embedding", |
| torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), |
| persistent=False, |
| ) |
|
|
| def forward(self, seqlen: int): |
| return self.positional_embedding[:seqlen, :] |
|
|
|
|
| class UASAudioAttention(nn.Module): |
| """Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.embed_dim = config.d_model |
| self.num_heads = config.encoder_attention_heads |
| self.dropout = config.attention_dropout |
| self.head_dim = self.embed_dim // self.num_heads |
| self.num_key_value_groups = 1 |
| self.config = config |
|
|
| if (self.head_dim * self.num_heads) != self.embed_dim: |
| raise ValueError( |
| f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" |
| f" and `num_heads`: {self.num_heads})." |
| ) |
| self.scaling = self.head_dim**-0.5 |
| self.attention_dropout = 0.0 |
| self.is_decoder = False |
| self.is_causal = False |
| self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) |
| self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) |
| self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) |
| self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| **kwargs, |
    ) -> torch.Tensor:
        """Input shape: packed (total_seq_len, embed_dim); `cu_seqlens` marks the window boundaries."""
|
|
| seq_length, _ = hidden_states.size() |
|
|
| query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1) |
| key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1) |
| value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1) |
|
|
| query_states = query_states.transpose(0, 1).unsqueeze(0) |
| key_states = key_states.transpose(0, 1).unsqueeze(0) |
| value_states = value_states.transpose(0, 1).unsqueeze(0) |
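        # All sequences are packed along the time axis into a single pseudo-batch of size 1;
        # cu_seqlens marks the per-window boundaries used by variable-length attention implementations.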
| max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() |
|
|
| attention_interface: Callable = eager_attention_forward |
| if self.config._attn_implementation != "eager": |
| attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
|
|
| attn_output, _ = attention_interface( |
| self, |
| query_states, |
| key_states, |
| value_states, |
| attention_mask=attention_mask, |
| dropout=0.0 if not self.training else self.attention_dropout, |
| scaling=self.scaling, |
| cu_seq_lens_q=cu_seqlens, |
| cu_seq_lens_k=cu_seqlens, |
| max_length_q=max_seqlen, |
| max_length_k=max_seqlen, |
| is_causal=False, |
| **kwargs, |
| ) |
|
|
| attn_output = attn_output.reshape(seq_length, -1).contiguous() |
| attn_output = self.out_proj(attn_output) |
|
|
| return attn_output |
|
|
|
|
| class UASAudioEncoderLayer(GradientCheckpointingLayer): |
| def __init__(self, config: UASAudioEncoderConfig): |
| super().__init__() |
| self.embed_dim = config.d_model |
| self.self_attn = UASAudioAttention(config) |
| self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) |
| self.dropout = config.dropout |
| self.activation_fn = ACT2FN[config.activation_function] |
| self.activation_dropout = config.activation_dropout |
| self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) |
| self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) |
| self.final_layer_norm = nn.LayerNorm(self.embed_dim) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| cu_seqlens: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| **kwargs, |
    ) -> tuple[torch.Tensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): packed input to the layer of shape `(total_seq_len, embed_dim)`
            cu_seqlens (`torch.Tensor`): cumulative sequence lengths delimiting the attention windows
                inside the packed `hidden_states`.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(1, 1, total_seq_len, total_seq_len)` where positions outside a window are set to a
                very large negative value.
        """
| residual = hidden_states |
| hidden_states = self.self_attn_layer_norm(hidden_states) |
| hidden_states = self.self_attn( |
| hidden_states=hidden_states, |
| cu_seqlens=cu_seqlens, |
| attention_mask=attention_mask, |
| **kwargs, |
| ) |
| hidden_states = residual + hidden_states |
| residual = hidden_states |
| hidden_states = self.final_layer_norm(hidden_states) |
| hidden_states = self.fc1(hidden_states) |
| hidden_states = self.activation_fn(hidden_states) |
| hidden_states = self.fc2(hidden_states) |
| hidden_states = residual + hidden_states |
|
|
| if hidden_states.dtype == torch.float16: |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
|
|
| outputs = (hidden_states,) |
|
|
| return outputs |
|
|
|
|
| class UASAudioEncoder(PreTrainedModel): |
| config: UASAudioEncoderConfig |
| main_input_name = "input_features" |
| input_modalities = "audio" |
| _no_split_modules = ["UASAudioEncoderLayer"] |
| _supports_sdpa = True |
|
|
| def __init__(self, config: UASAudioEncoderConfig): |
| super().__init__(config) |
| self.dropout = config.dropout |
|
|
| embed_dim = config.d_model |
| self.num_mel_bins = config.num_mel_bins |
| self.max_source_positions = config.max_source_positions |
| self.n_window = config.n_window |
| self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim) |
| self.layers = nn.ModuleList([UASAudioEncoderLayer(config) for _ in range(config.encoder_layers)]) |
| self.ln_post = nn.LayerNorm(config.d_model) |
| self.gradient_checkpointing = False |
| self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1) |
| self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) |
| self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1) |
| self.conv_out = nn.Linear( |
| config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2), |
| config.d_model, |
| bias=False, |
| ) |
| self.n_window_infer = self.config.n_window_infer |
| self.conv_chunksize = self.config.conv_chunksize |
| self.post_init() |
|
|
| def _freeze_parameters(self): |
| for param in self.parameters(): |
| param.requires_grad = False |
| self._requires_grad = False |
|
|
    def get_input_embeddings(self) -> nn.Module:
        return self.conv2d1


    def set_input_embeddings(self, value: nn.Module):
        self.conv2d1 = value
|
|
    def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> Optional[torch.Tensor]:
| |
| |
| |
| |
| if self.config._attn_implementation == "flash_attention_2": |
| return None |
|
|
| seq_length = inputs_tensor.shape[0] |
| attention_mask = torch.full( |
| [1, 1, seq_length, seq_length], |
| torch.finfo(inputs_tensor.dtype).min, |
| device=inputs_tensor.device, |
| dtype=inputs_tensor.dtype, |
| ) |
| for i in range(1, len(cu_seqlens)): |
| attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0 |
| return attention_mask |
|
|
| @auto_docstring |
| def forward( |
| self, |
| input_features, |
| feature_lens=None, |
| aftercnn_lens=None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ): |
| r""" |
        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
            Number of valid mel frames for each audio.
        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
            Number of frames per audio after the convolutional front-end (recomputed internally
            from `feature_lens`).
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| """ |
| return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', False) |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else getattr(self.config, 'output_hidden_states', False) |
| ) |
|
|
| aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) |
| chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() |
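        # Build one length entry per window of n_window * 2 mel frames; the tail window of each
        # audio keeps the remainder (or a full window when the length divides evenly).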
|
|
| chunk_lengths = torch.tensor( |
| [self.n_window * 2] * chunk_num.sum(), |
| dtype=torch.long, |
| device=feature_lens.device, |
| ) |
| tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] |
| chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) |
| chunk_lengths[chunk_lengths == 0] = self.n_window * 2 |
|
|
| chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) |
| padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2) |
| feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) |
| padded_mask_after_cnn = nn.utils.rnn.pad_sequence( |
| [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn], |
| batch_first=True, |
| ) |
| padded_feature = padded_feature.unsqueeze(1) |
| |
| padded_embeds = [] |
| for chunk in padded_feature.split(self.conv_chunksize, dim=0): |
| padded_embed = F.gelu(self.conv2d1(chunk)) |
| padded_embed = F.gelu(self.conv2d2(padded_embed)) |
| padded_embed = F.gelu(self.conv2d3(padded_embed)) |
| padded_embeds.append(padded_embed) |
| padded_embed = torch.cat(padded_embeds, dim=0) |
| b, c, f, t = padded_embed.size() |
| padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)) |
|
|
| positional_embedding = ( |
| self.positional_embedding.positional_embedding[: padded_embed.shape[1], :] |
| .unsqueeze(0) |
| .to(padded_embed.dtype) |
| ) |
| padded_embed = padded_embed + positional_embedding |
| hidden_states = padded_embed[padded_mask_after_cnn] |
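        # Build cumulative attention-window boundaries in post-CNN frames: each inference window
        # spans n_window_infer // (n_window * 2) conv chunks.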
| cu_chunk_lens = [0] |
| window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2)) |
| for cnn_len in aftercnn_lens: |
| cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) |
| remainder = cnn_len % window_aftercnn |
| if remainder != 0: |
| cu_chunk_lens += [remainder] |
| cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32) |
|
|
        all_hidden_states = (hidden_states,) if output_hidden_states else None
|
|
        # Block-diagonal mask restricting attention to each cu_seqlens window (None for
        # flash_attention_2, where cu_seqlens is consumed directly by the varlen kernels).
        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)

        for encoder_layer in self.layers:
            layer_outputs = encoder_layer(
                hidden_states,
                cu_seqlens,
                attention_mask=attention_mask,
            )
|
|
| hidden_states = layer_outputs[0] |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| hidden_states = self.ln_post(hidden_states) |
|
|
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| if not return_dict: |
| return tuple(v for v in [hidden_states, all_hidden_states] if v is not None) |
|
|
| return BaseModelOutput( |
| last_hidden_state=hidden_states, |
| hidden_states=all_hidden_states, |
| ) |
|
|
| def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"): |
| """ |
| Pads a sequence of tensors to their maximum length on indicated `padding_side`. |
| Then prepares a mask so that pad tokens are not attended to. |
| """ |
| max_len = tensor_len.max() |
| dim = tensor_list[0].shape[0] |
| padded_tensor = torch.full( |
| size=(len(tensor_list), dim, max_len), |
| fill_value=padding_value, |
| dtype=self.dtype, |
| device=tensor_list[0].device, |
| ) |
|
|
| batch_mask = torch.zeros( |
| (len(tensor_len), max_len), |
| dtype=torch.long, |
| device=padded_tensor.device, |
| ) |
| for i, length in enumerate(tensor_len): |
| batch_mask[i, :length] = 1 |
| padded_tensor[i, :, :length] = tensor_list[i] |
|
|
| feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 |
| max_len_after_cnn = feature_lens_after_cnn.max() |
| batch_mask_after_cnn = torch.zeros( |
| (len(tensor_len), max_len_after_cnn), |
| dtype=torch.long, |
| device=padded_tensor.device, |
| ) |
| for i, length in enumerate(feature_lens_after_cnn): |
| batch_mask_after_cnn[i, :length] = 1 |
| return ( |
| padded_tensor, |
| batch_mask.unsqueeze(1), |
| batch_mask_after_cnn.bool(), |
| ) |
|
|
|
|
| class Adapter(nn.Module): |
| def __init__( |
| self, |
| d_model: int, |
| n_embd: int, |
| ): |
| super().__init__() |
| self.audio_projector = torch.nn.Sequential( |
| torch.nn.Linear(d_model, n_embd), |
| torch.nn.GELU(), |
| torch.nn.Linear(n_embd, n_embd) |
| ) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| x = self.audio_projector(x) |
| return x |
|
|
|
|
| class UASAudioForCausalLM(PreTrainedModel, GenerationMixin): |
| config_class = UASAudioConfig |
| main_input_name = "input_ids" |
    supports_gradient_checkpointing = True

    def __init__(self, config: UASAudioConfig):
| super().__init__(config) |
| if isinstance(config.dtype, str): |
| dtype = getattr(torch, config.dtype) |
| else: |
| dtype = config.dtype |
| self.bf16 = dtype == torch.bfloat16 |
|
|
| self.llm = Qwen2ForCausalLM(config.text_config) |
| self.audio_encoder = UASAudioEncoder(config.audio_encoder_config) |
|
|
| d_model = config.audio_encoder_config.d_model |
|
|
| self.adapter = Adapter( |
| d_model, |
| config.text_config.hidden_size, |
| ) |
| self.audio_token = config.audio_token |
|
|
| if self.bf16: |
| self.audio_encoder = self.audio_encoder.bfloat16() |
| self.adapter = self.adapter.bfloat16() |
|
|
| self.post_init() |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| mels=None, |
| mel_masks=None, |
| past_key_values=None, |
| **kwargs |
| ): |
| |
| if past_key_values is not None: |
| outputs = self.llm( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| **kwargs |
| ) |
| else: |
| |
| hidden_states = self.embedding_with_audio_tokens(input_ids, mels, mel_masks) |
| outputs = self.llm( |
| inputs_embeds=hidden_states, |
| attention_mask=attention_mask, |
| past_key_values=None, |
| **kwargs |
| ) |
| return outputs |
|
|
| def embedding_with_audio_tokens( |
| self, |
| input_ids, |
| mels, |
| mel_masks |
| ): |
| """ |
| Get input embeddings for the LLM, replacing audio token positions with audio features from the audio encoder. |
| """ |
| hidden_states = self.embeddings(input_ids) |
| if mels is None: |
| return hidden_states |
|
|
| audio_embeddings = self.audio_encoding(mels, mel_masks) |
| audio_embeddings = self.adapter(audio_embeddings) |
| audio_mask = input_ids == self.audio_token |
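        # Scatter the adapted audio frames into the placeholder positions; the number of audio
        # placeholder tokens in `input_ids` must match the number of encoded audio frames.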
| hidden_states[audio_mask] = audio_embeddings |
| return hidden_states |
|
|
| def audio_encoding( |
| self, |
| audio_features: torch.Tensor, |
| audio_features_mask: torch.Tensor, |
| output_hidden_states: bool = False |
| ): |
| """ |
| Encode audio features into embeddings. |
| |
| Args: |
| audio_features: Audio features tensor |
| audio_features_mask: Audio features mask |
| output_hidden_states: Whether to return hidden states from all encoder layers |
| |
| Returns: |
| If output_hidden_states=False: audio_features_encoded tensor |
| If output_hidden_states=True: BaseModelOutput with last_hidden_state and hidden_states |
| """ |
| feature_lens = audio_features_mask.sum(-1).long() |
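        # Pack the valid mel frames of the whole batch into a single (num_mel_bins, total_frames) tensor.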
| input_features = audio_features.permute(0, 2, 1)[audio_features_mask.bool()].permute(1, 0) |
|
|
| audio_encoder_outputs = self.audio_encoder( |
| input_features, |
| feature_lens=feature_lens, |
| output_hidden_states=output_hidden_states, |
| return_dict=output_hidden_states, |
| ) |
|
|
| if output_hidden_states: |
| |
| return audio_encoder_outputs |
| else: |
| |
| |
| if isinstance(audio_encoder_outputs, tuple): |
| return audio_encoder_outputs[0] |
| return audio_encoder_outputs |
|
|
| @property |
| def embeddings(self): |
| """Return the model's input embeddings - required for GenerationMixin""" |
| return self.llm.model.embed_tokens |
|
|
| def forward_with_detailed_outputs( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| mels=None, |
| mel_masks=None, |
| past_key_values=None, |
| output_hidden_states: bool = True, |
| **kwargs |
| ): |
| """ |
| Forward pass that returns detailed outputs including: |
| - Audio encoder final output |
| - Audio features after projector (adapter) |
| - Text embedding features |
| - Hidden states from each layer (separated for audio and text) |
| |
| Args: |
| input_ids: Input token ids |
| attention_mask: Attention mask |
| mels: Audio mel features |
| mel_masks: Audio mel masks |
| past_key_values: Past key values for generation |
| output_hidden_states: Whether to return hidden states from all layers |
| **kwargs: Additional arguments |
| |
| Returns: |
| dict containing: |
| - audio_encoder_output: Final output from audio encoder |
| - audio_features_after_adapter: Audio features after projector/adapter |
| - text_embeddings: Text embedding features (before audio replacement) |
| - audio_encoder_hidden_states: Tuple of hidden states from each audio encoder layer |
| - llm_hidden_states: Tuple of hidden states from each LLM layer (mixed audio+text) |
| - llm_hidden_states_text_only: Tuple of text-only hidden states from each LLM layer |
| - llm_hidden_states_audio_only: Tuple of audio-only hidden states from each LLM layer |
| - llm_outputs: Full LLM outputs (CausalLMOutputWithPast) |
| """ |
| |
| |
| text_embeddings_pure = self.embeddings(input_ids) |
|
|
| |
| audio_encoder_output = None |
| audio_features_after_adapter = None |
| audio_encoder_hidden_states = None |
| audio_mask = None |
|
|
| |
| input_embeddings_for_llm = text_embeddings_pure.clone() |
|
|
| |
| audio_mask = input_ids == self.audio_token |
|
|
| if mels is not None: |
| |
| audio_encoder_outputs = self.audio_encoding( |
| mels, |
| mel_masks, |
| output_hidden_states=output_hidden_states |
| ) |
|
|
| if output_hidden_states: |
| audio_encoder_output = audio_encoder_outputs.last_hidden_state |
| audio_encoder_hidden_states = audio_encoder_outputs.hidden_states |
| else: |
| audio_encoder_output = audio_encoder_outputs |
| audio_encoder_hidden_states = None |
|
|
| |
| audio_features_after_adapter = self.adapter(audio_encoder_output) |
|
|
| |
| input_embeddings_for_llm[audio_mask] = audio_features_after_adapter |
|
|
| |
| |
| if audio_mask.any(): |
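            # For analysis, also build a text-only view of the embeddings: drop the audio positions
            # of each sample and right-pad the remaining text embeddings to a common length.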
| |
| batch_size = text_embeddings_pure.shape[0] |
| text_embeddings_list = [] |
|
|
| for i in range(batch_size): |
| |
| text_mask = ~audio_mask[i] |
| |
| text_emb = text_embeddings_pure[i][text_mask] |
| text_embeddings_list.append(text_emb) |
|
|
| |
| |
| max_text_len = max(emb.shape[0] for emb in text_embeddings_list) if text_embeddings_list else 0 |
| if max_text_len > 0: |
| hidden_size = text_embeddings_pure.shape[2] |
| text_embeddings_pure = torch.zeros( |
| (batch_size, max_text_len, hidden_size), |
| dtype=text_embeddings_pure.dtype, |
| device=text_embeddings_pure.device |
| ) |
| for i, emb in enumerate(text_embeddings_list): |
| text_len = emb.shape[0] |
| text_embeddings_pure[i, :text_len] = emb |
| else: |
| |
| hidden_size = text_embeddings_pure.shape[2] |
| text_embeddings_pure = torch.zeros( |
| (batch_size, 0, hidden_size), |
| dtype=text_embeddings_pure.dtype, |
| device=text_embeddings_pure.device |
| ) |
| |
|
|
| |
| if past_key_values is not None: |
| |
| llm_outputs = self.llm( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| past_key_values=past_key_values, |
| output_hidden_states=output_hidden_states, |
| return_dict=True, |
| **kwargs |
| ) |
| else: |
| |
| llm_outputs = self.llm( |
| inputs_embeds=input_embeddings_for_llm, |
| attention_mask=attention_mask, |
| past_key_values=None, |
| output_hidden_states=output_hidden_states, |
| return_dict=True, |
| **kwargs |
| ) |
|
|
| |
| llm_hidden_states = llm_outputs.hidden_states if output_hidden_states else None |
|
|
| |
| llm_hidden_states_text_only = None |
| llm_hidden_states_audio_only = None |
|
|
| if output_hidden_states and llm_hidden_states is not None and audio_mask is not None: |
| |
| llm_hidden_states_text_only = tuple() |
| llm_hidden_states_audio_only = tuple() |
|
|
| batch_size = llm_hidden_states[0].shape[0] |
| hidden_size = llm_hidden_states[0].shape[2] |
|
|
| for layer_hidden_states in llm_hidden_states: |
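                # Split each layer's hidden states into text-only and audio-only sequences per
                # sample, then right-pad each group to its own maximum length.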
| |
| |
|
|
| |
| text_hidden_list = [] |
| audio_hidden_list = [] |
|
|
| for i in range(batch_size): |
| |
| text_mask = ~audio_mask[i] |
| audio_mask_i = audio_mask[i] |
|
|
| |
| text_hidden_i = layer_hidden_states[i][text_mask] |
| text_hidden_list.append(text_hidden_i) |
|
|
| |
| audio_hidden_i = layer_hidden_states[i][audio_mask_i] |
| audio_hidden_list.append(audio_hidden_i) |
|
|
| |
| |
| max_text_len = max(emb.shape[0] for emb in text_hidden_list) if text_hidden_list else 0 |
| max_audio_len = max(emb.shape[0] for emb in audio_hidden_list) if audio_hidden_list else 0 |
|
|
| if max_text_len > 0: |
| text_hidden = torch.zeros( |
| (batch_size, max_text_len, hidden_size), |
| dtype=layer_hidden_states.dtype, |
| device=layer_hidden_states.device |
| ) |
| for i, emb in enumerate(text_hidden_list): |
| text_len = emb.shape[0] |
| text_hidden[i, :text_len] = emb |
| else: |
| |
| text_hidden = torch.zeros( |
| (batch_size, 0, hidden_size), |
| dtype=layer_hidden_states.dtype, |
| device=layer_hidden_states.device |
| ) |
|
|
| if max_audio_len > 0: |
| audio_hidden = torch.zeros( |
| (batch_size, max_audio_len, hidden_size), |
| dtype=layer_hidden_states.dtype, |
| device=layer_hidden_states.device |
| ) |
| for i, emb in enumerate(audio_hidden_list): |
| audio_len = emb.shape[0] |
| audio_hidden[i, :audio_len] = emb |
| else: |
| |
| audio_hidden = torch.zeros( |
| (batch_size, 0, hidden_size), |
| dtype=layer_hidden_states.dtype, |
| device=layer_hidden_states.device |
| ) |
|
|
| llm_hidden_states_text_only += (text_hidden,) |
| llm_hidden_states_audio_only += (audio_hidden,) |
|
|
| return { |
| "audio_encoder_output": audio_encoder_output, |
| "audio_features_after_adapter": audio_features_after_adapter, |
| "text_embeddings": text_embeddings_pure, |
| "audio_encoder_hidden_states": audio_encoder_hidden_states, |
| "llm_hidden_states": llm_hidden_states, |
| "llm_hidden_states_text_only": llm_hidden_states_text_only, |
| "llm_hidden_states_audio_only": llm_hidden_states_audio_only, |
| "llm_outputs": llm_outputs, |
| } |
|
|
| def generate( |
| self, |
| input_ids, |
| attention_mask=None, |
| mels=None, |
| mel_masks=None, |
| generation_config=None, |
| **generate_kwargs |
| ): |
| """ |
| New implementation of the generate method to support audio inputs. |
| |
| This method will: |
| 1. Handle the initial processing of audio inputs; |
| 2. Call the underlying LLM's generate method with the appropriate embeddings; |
| 3. The incremental decoding will be handled by the LLM's generate method using past_key_values. |
| """ |
| |
| input_embeddings = self.embedding_with_audio_tokens(input_ids, mels, mel_masks) |
|
|
| |
| |
| |
| outputs = self.llm.generate( |
| inputs_embeds=input_embeddings, |
| attention_mask=attention_mask, |
| generation_config=generation_config, |
| use_cache=True, |
| **generate_kwargs |
| ) |
|
|
| return outputs |
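    # Illustrative call (a minimal sketch; `inputs`, `mels`, and the exact feature shapes are
    # assumptions about the upstream processor):
    #
    #   generated = model.generate(
    #       input_ids=inputs.input_ids,          # contains audio placeholder tokens
    #       attention_mask=inputs.attention_mask,
    #       mels=mels,                           # (batch, num_mel_bins, frames)
    #       mel_masks=mel_masks,                 # (batch, frames)
    #       max_new_tokens=256,
    #   )
    #
    # Because generation runs on `inputs_embeds`, the returned sequences typically contain only the
    # newly generated token ids.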
|
|
|
|
| class UASAudioEncoderOnly(PreTrainedModel): |
| """ |
| UASAudio encoder-only model that contains only the audio encoder and adapter. |
| Input: audio features |
| Output: features processed by encoder and adapter |
| """ |
| config_class = UASAudioEncoderOnlyConfig |
| main_input_name = "input_features" |
| input_modalities = "audio" |
|
|
| def __init__(self, config: UASAudioEncoderOnlyConfig): |
| super().__init__(config) |
        if isinstance(config.dtype, str):
            dtype = getattr(torch, config.dtype)
        else:
            dtype = config.dtype
| self.bf16 = dtype == torch.bfloat16 |
|
|
| self.audio_encoder = UASAudioEncoder(config.audio_encoder_config) |
|
|
| d_model = config.audio_encoder_config.d_model |
|
|
| self.adapter = Adapter( |
| d_model, |
| config.hidden_size, |
| ) |
|
|
| if self.bf16: |
| self.audio_encoder = self.audio_encoder.bfloat16() |
| self.adapter = self.adapter.bfloat16() |
|
|
| self.post_init() |
|
|
| def forward( |
| self, |
| input_features: torch.Tensor, |
| feature_lens: Optional[torch.Tensor] = None, |
| **kwargs |
| ): |
| """ |
| Forward pass through audio encoder and adapter. |
| |
| Args: |
            input_features: Packed audio mel features of shape (num_mel_bins, total_frames), i.e. the
                mel frames of all audios concatenated along the time axis
| feature_lens: Optional tensor of shape (batch_size,) indicating the length of each sequence |
| **kwargs: Additional arguments passed to audio encoder |
| |
| Returns: |
| torch.Tensor: Features processed by encoder and adapter |
| """ |
| |
| audio_features_encoded = self.audio_encoder( |
| input_features, |
| feature_lens=feature_lens, |
| **kwargs |
| ) |
|
|
| |
| if isinstance(audio_features_encoded, tuple): |
| audio_features_encoded = audio_features_encoded[0] |
| elif hasattr(audio_features_encoded, "last_hidden_state"): |
| audio_features_encoded = audio_features_encoded.last_hidden_state |
|
|
| |
| output_features = self.adapter(audio_features_encoded) |
|
|
| return output_features |
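    # Illustrative usage (a minimal sketch): `input_features` uses the same packed layout as
    # UASAudioEncoder, i.e. (num_mel_bins, total_frames) plus per-audio `feature_lens`:
    #
    #   feats = model(input_features, feature_lens=feature_lens)
    #   # feats: (total_output_frames, hidden_size), ready to be consumed by the LLM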
|
|