# Unified_Audio_Schema / modeling_uas_audio.py
from collections.abc import Callable
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from transformers import PreTrainedModel, Qwen2ForCausalLM
from transformers.activations import ACT2FN
from transformers.generation import GenerationMixin
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.utils import auto_docstring
from .configuration_uas_audio import UASAudioConfig, UASAudioEncoderConfig, UASAudioEncoderOnlyConfig
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
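# A minimal shape sketch (illustrative, not part of the model): with 2 KV heads
# repeated 4x, a (1, 2, 10, 64) tensor becomes (1, 8, 10, 64) to serve 8 query heads.
# >>> repeat_kv(torch.zeros(1, 2, 10, 64), 4).shape
# torch.Size([1, 8, 10, 64])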
def _get_feat_extract_output_lengths(input_lengths):
"""
Computes the output length of the convolutional layers and the output length of the audio encoder
"""
input_lengths_leave = input_lengths % 100
feat_lengths = (input_lengths_leave - 1) // 2 + 1
output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
return output_lengths
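# Worked example (a sanity check of the block arithmetic above): 100 mel frames map to
# 13 encoder frames, and a 50-frame remainder ceil-halves three times (50 -> 25 -> 13 -> 7).
# >>> _get_feat_extract_output_lengths(torch.tensor([100, 150, 250]))
# tensor([13, 20, 33])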
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
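# Minimal shape check (the dummy module below is an assumption for illustration only):
# >>> class _Dummy(nn.Module):
# ...     num_key_value_groups = 1
# >>> q = k = v = torch.randn(1, 8, 16, 64)
# >>> out, weights = eager_attention_forward(_Dummy(), q, k, v, None, scaling=64 ** -0.5)
# >>> out.shape, weights.shape
# (torch.Size([1, 16, 8, 64]), torch.Size([1, 8, 16, 16]))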
class SinusoidsPositionEmbedding(nn.Module):
def __init__(self, length, channels, max_timescale=10000):
super().__init__()
if channels % 2 != 0:
raise ValueError("SinusoidsPositionEmbedding needs even channels input")
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
self.register_buffer(
"positional_embedding",
torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
persistent=False,
)
def forward(self, seqlen: int):
return self.positional_embedding[:seqlen, :]
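# Shape sketch (illustrative): a table of 1500 positions over 512 (even) channels,
# sliced to the first `seqlen` rows on demand.
# >>> pe = SinusoidsPositionEmbedding(1500, 512)
# >>> pe(100).shape
# torch.Size([100, 512])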
class UASAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.embed_dim = config.d_model
self.num_heads = config.encoder_attention_heads
self.dropout = config.attention_dropout
self.head_dim = self.embed_dim // self.num_heads
self.num_key_value_groups = 1 # needed for eager attention
self.config = config
if (self.head_dim * self.num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0  # runtime attention dropout is disabled; the config value lives in `self.dropout`
self.is_decoder = False
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
def forward(
self,
hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
    ) -> torch.Tensor:
        """Input shape: (total_seq_len, embed_dim), with batch sequences packed along the time axis"""
        seq_length, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
key_states = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
value_states = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, _ = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask=attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
cu_seq_lens_q=cu_seqlens, # pass cu seq lens for FA2
cu_seq_lens_k=cu_seqlens,
max_length_q=max_seqlen,
max_length_k=max_seqlen,
is_causal=False,
**kwargs,
)
attn_output = attn_output.reshape(seq_length, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output
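# Hypothetical usage sketch (the config values below are assumptions, not defaults):
# two sequences of 5 and 3 frames packed into one (8, d_model) tensor, with chunk
# boundaries given through `cu_seqlens`. With the eager path and `attention_mask=None`,
# frames attend across chunk boundaries; pass the block-diagonal mask built by
# `UASAudioEncoder._prepare_attention_mask` to restrict attention to within-chunk positions.
# >>> cfg = UASAudioEncoderConfig(d_model=128, encoder_attention_heads=4)
# >>> attn = UASAudioAttention(cfg)
# >>> x = torch.randn(8, 128)
# >>> cu = torch.tensor([0, 5, 8], dtype=torch.int32)
# >>> attn(x, cu_seqlens=cu, attention_mask=None).shape
# torch.Size([8, 128])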
class UASAudioEncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: UASAudioEncoderConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = UASAudioAttention(config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
self.activation_dropout = config.activation_dropout
self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
attention_mask=attention_mask,
**kwargs,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
outputs = (hidden_states,)
return outputs
class UASAudioEncoder(PreTrainedModel):
config: UASAudioEncoderConfig
main_input_name = "input_features"
input_modalities = "audio"
_no_split_modules = ["UASAudioEncoderLayer"]
_supports_sdpa = True
def __init__(self, config: UASAudioEncoderConfig):
super().__init__(config)
self.dropout = config.dropout
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.n_window = config.n_window
self.positional_embedding = SinusoidsPositionEmbedding(self.max_source_positions, embed_dim)
self.layers = nn.ModuleList([UASAudioEncoderLayer(config) for _ in range(config.encoder_layers)])
self.ln_post = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
self.conv2d1 = nn.Conv2d(1, config.downsample_hidden_size, 3, 2, padding=1)
self.conv2d2 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
self.conv2d3 = nn.Conv2d(config.downsample_hidden_size, config.downsample_hidden_size, 3, 2, padding=1)
self.conv_out = nn.Linear(
config.downsample_hidden_size * ((((config.num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2),
config.d_model,
bias=False,
)
self.n_window_infer = self.config.n_window_infer
self.conv_chunksize = self.config.conv_chunksize
self.post_init()
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
self._requires_grad = False
    def get_input_embeddings(self) -> nn.Module:
        return self.conv2d1

    def set_input_embeddings(self, value: nn.Module):
        self.conv2d1 = value
def _prepare_attention_mask(self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # Flash Attention 2 doesn't need a 4D mask and relies on `cu_seqlens`/`max_seqlen` instead.
        # NOTE: the created attention mask only approximates FA2's ragged attention by allowing
        # bidirectional attention within `cu_seqlens` blocks and no attention between blocks,
        # so it will not be a 100% match for FA2's `varlen` path.
if self.config._attn_implementation == "flash_attention_2":
return None
seq_length = inputs_tensor.shape[0]
attention_mask = torch.full(
[1, 1, seq_length, seq_length],
torch.finfo(inputs_tensor.dtype).min,
device=inputs_tensor.device,
dtype=inputs_tensor.dtype,
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
return attention_mask
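    # Illustrative mask layout: for cu_seqlens = [0, 3, 5] over 5 packed frames the result
    # is block-diagonal -- zeros inside the 3x3 and 2x2 diagonal blocks, dtype-min everywhere
    # else -- giving bidirectional attention within each chunk and none across chunks.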
@auto_docstring
def forward(
self,
input_features,
feature_lens=None,
aftercnn_lens=None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
        feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
            number of mel frames in each sequence
        aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
            number of frames per sequence after convolutional downsampling (recomputed internally from
            `feature_lens`; any value passed in is overridden)
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
return_dict = return_dict if return_dict is not None else getattr(self.config, 'use_return_dict', False)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else getattr(self.config, 'output_hidden_states', False)
)
        # NOTE: `aftercnn_lens` is always recomputed from `feature_lens`, overriding the argument
        aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
        # Split each utterance into windows of `2 * n_window` mel frames; the tail window of
        # each utterance keeps the remainder (or a full window when the length divides evenly).
        chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()
        chunk_lengths = torch.tensor(
            [self.n_window * 2] * chunk_num.sum(),
            dtype=torch.long,
            device=feature_lens.device,
        )
        tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
        chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
        chunk_lengths[chunk_lengths == 0] = self.n_window * 2
chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
[torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
batch_first=True,
)
padded_feature = padded_feature.unsqueeze(1)
# Split to chunk to avoid OOM during convolution
padded_embeds = []
for chunk in padded_feature.split(self.conv_chunksize, dim=0):
padded_embed = F.gelu(self.conv2d1(chunk))
padded_embed = F.gelu(self.conv2d2(padded_embed))
padded_embed = F.gelu(self.conv2d3(padded_embed))
padded_embeds.append(padded_embed)
padded_embed = torch.cat(padded_embeds, dim=0)
b, c, f, t = padded_embed.size()
padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))
        positional_embedding = (
            self.positional_embedding(padded_embed.shape[1]).unsqueeze(0).to(padded_embed.dtype)
        )
padded_embed = padded_embed + positional_embedding
hidden_states = padded_embed[padded_mask_after_cnn]
        cu_chunk_lens = [0]
        # Each attention window during inference spans `n_window_infer // (2 * n_window)`
        # convolution chunks' worth of post-CNN frames.
        window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
for cnn_len in aftercnn_lens:
cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
remainder = cnn_len % window_aftercnn
if remainder != 0:
cu_chunk_lens += [remainder]
cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
        all_hidden_states = (hidden_states,) if output_hidden_states else None
for layer_idx, encoder_layer in enumerate(self.layers):
layer_outputs = encoder_layer(
hidden_states,
cu_seqlens,
)
hidden_states = layer_outputs[0]
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states = self.ln_post(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
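    # Hypothetical call sketch (128 mel bins assumed): two clips of 100 and 80 frames
    # packed along time; 100 -> 13 and 80 -> 10 post-CNN frames, so the packed output
    # holds 23 frames of width d_model.
    # >>> feats = torch.randn(128, 180)            # (num_mel_bins, total_frames)
    # >>> lens = torch.tensor([100, 80])
    # >>> out = encoder(feats, feature_lens=lens, return_dict=True)
    # >>> out.last_hidden_state.shape
    # torch.Size([23, 512])                        # with d_model = 512 assumed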
def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
"""
Pads a sequence of tensors to their maximum length on indicated `padding_side`.
Then prepares a mask so that pad tokens are not attended to.
"""
max_len = tensor_len.max()
dim = tensor_list[0].shape[0]
padded_tensor = torch.full(
size=(len(tensor_list), dim, max_len),
fill_value=padding_value,
dtype=self.dtype,
device=tensor_list[0].device,
)
batch_mask = torch.zeros(
(len(tensor_len), max_len),
dtype=torch.long,
device=padded_tensor.device,
)
for i, length in enumerate(tensor_len):
batch_mask[i, :length] = 1
padded_tensor[i, :, :length] = tensor_list[i]
feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
max_len_after_cnn = feature_lens_after_cnn.max()
batch_mask_after_cnn = torch.zeros(
(len(tensor_len), max_len_after_cnn),
dtype=torch.long,
device=padded_tensor.device,
)
for i, length in enumerate(feature_lens_after_cnn):
batch_mask_after_cnn[i, :length] = 1
return (
padded_tensor,
batch_mask.unsqueeze(1),
batch_mask_after_cnn.bool(),
)
class Adapter(nn.Module):
def __init__(
self,
d_model: int,
n_embd: int,
):
super().__init__()
self.audio_projector = torch.nn.Sequential(
torch.nn.Linear(d_model, n_embd),
torch.nn.GELU(),
torch.nn.Linear(n_embd, n_embd)
)
def forward(self, x: Tensor) -> Tensor:
x = self.audio_projector(x)
return x
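# Shape sketch (dimensions assumed for illustration): project 1280-dim encoder frames
# into a 2048-dim LLM embedding space.
# >>> adapter = Adapter(d_model=1280, n_embd=2048)
# >>> adapter(torch.randn(23, 1280)).shape
# torch.Size([23, 2048])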
class UASAudioForCausalLM(PreTrainedModel, GenerationMixin):
config_class = UASAudioConfig
main_input_name = "input_ids"
supports_gradient_checkpointing = True
def __init__(self, config: UASAudioConfig):
super().__init__(config)
if isinstance(config.dtype, str):
dtype = getattr(torch, config.dtype)
else:
dtype = config.dtype
self.bf16 = dtype == torch.bfloat16
self.llm = Qwen2ForCausalLM(config.text_config)
self.audio_encoder = UASAudioEncoder(config.audio_encoder_config)
d_model = config.audio_encoder_config.d_model
self.adapter = Adapter(
d_model,
config.text_config.hidden_size,
)
self.audio_token = config.audio_token
if self.bf16:
self.audio_encoder = self.audio_encoder.bfloat16()
self.adapter = self.adapter.bfloat16()
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
mels=None,
mel_masks=None,
past_key_values=None,
**kwargs
):
# If past_key_values are provided, we are in the generation phase and should not process audio inputs again
if past_key_values is not None:
outputs = self.llm(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
**kwargs
)
else:
# First get the text embeddings for the input_ids, then replace audio token positions with audio embeddings
hidden_states = self.embedding_with_audio_tokens(input_ids, mels, mel_masks)
outputs = self.llm(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
past_key_values=None,
**kwargs
)
return outputs
def embedding_with_audio_tokens(
self,
input_ids,
mels,
mel_masks
):
"""
Get input embeddings for the LLM, replacing audio token positions with audio features from the audio encoder.
"""
hidden_states = self.embeddings(input_ids)
if mels is None:
return hidden_states
audio_embeddings = self.audio_encoding(mels, mel_masks) # data -> feature
audio_embeddings = self.adapter(audio_embeddings)
audio_mask = input_ids == self.audio_token
hidden_states[audio_mask] = audio_embeddings
return hidden_states
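    # Illustrative contract: for input_ids == [[t0, AUD, AUD, t1]] (AUD == audio_token),
    # the encoder/adapter must produce exactly 2 frames, which overwrite rows 1 and 2 of
    # the returned embeddings; a count mismatch raises a shape error on the assignment.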
def audio_encoding(
self,
audio_features: torch.Tensor,
audio_features_mask: torch.Tensor,
output_hidden_states: bool = False
):
"""
Encode audio features into embeddings.
Args:
audio_features: Audio features tensor
audio_features_mask: Audio features mask
output_hidden_states: Whether to return hidden states from all encoder layers
Returns:
If output_hidden_states=False: audio_features_encoded tensor
If output_hidden_states=True: BaseModelOutput with last_hidden_state and hidden_states
"""
feature_lens = audio_features_mask.sum(-1).long() # [batch_size]
input_features = audio_features.permute(0, 2, 1)[audio_features_mask.bool()].permute(1, 0)
audio_encoder_outputs = self.audio_encoder(
input_features,
feature_lens=feature_lens,
output_hidden_states=output_hidden_states,
return_dict=output_hidden_states, # Only return dict when we need hidden states
)
if output_hidden_states:
# When output_hidden_states=True, we get BaseModelOutput
return audio_encoder_outputs
else:
# When output_hidden_states=False, we get tuple (hidden_states, ...)
# Extract the first element (hidden_states tensor) for backward compatibility
if isinstance(audio_encoder_outputs, tuple):
return audio_encoder_outputs[0]
return audio_encoder_outputs
@property
def embeddings(self):
"""Return the model's input embeddings - required for GenerationMixin"""
return self.llm.model.embed_tokens
def forward_with_detailed_outputs(
self,
input_ids=None,
attention_mask=None,
mels=None,
mel_masks=None,
past_key_values=None,
output_hidden_states: bool = True,
**kwargs
):
"""
Forward pass that returns detailed outputs including:
- Audio encoder final output
- Audio features after projector (adapter)
- Text embedding features
- Hidden states from each layer (separated for audio and text)
Args:
input_ids: Input token ids
attention_mask: Attention mask
mels: Audio mel features
mel_masks: Audio mel masks
past_key_values: Past key values for generation
output_hidden_states: Whether to return hidden states from all layers
**kwargs: Additional arguments
Returns:
dict containing:
- audio_encoder_output: Final output from audio encoder
- audio_features_after_adapter: Audio features after projector/adapter
- text_embeddings: Text embedding features (before audio replacement)
- audio_encoder_hidden_states: Tuple of hidden states from each audio encoder layer
- llm_hidden_states: Tuple of hidden states from each LLM layer (mixed audio+text)
- llm_hidden_states_text_only: Tuple of text-only hidden states from each LLM layer
- llm_hidden_states_audio_only: Tuple of audio-only hidden states from each LLM layer
- llm_outputs: Full LLM outputs (CausalLMOutputWithPast)
"""
# Get text embeddings (pure text, before audio replacement)
# Save original text embeddings for return (will have audio parts removed)
text_embeddings_pure = self.embeddings(input_ids)
        # Process audio if provided
        audio_encoder_output = None
        audio_features_after_adapter = None
        audio_encoder_hidden_states = None
        # Create embeddings for the LLM forward pass (may include audio features)
        input_embeddings_for_llm = text_embeddings_pure.clone()
        # Identify audio token positions (audio tokens may exist in input_ids even when no audio is provided)
        audio_mask = input_ids == self.audio_token
if mels is not None:
# Get audio encoder outputs with hidden states
audio_encoder_outputs = self.audio_encoding(
mels,
mel_masks,
output_hidden_states=output_hidden_states
)
if output_hidden_states:
audio_encoder_output = audio_encoder_outputs.last_hidden_state
audio_encoder_hidden_states = audio_encoder_outputs.hidden_states
else:
audio_encoder_output = audio_encoder_outputs
audio_encoder_hidden_states = None
# Apply adapter
audio_features_after_adapter = self.adapter(audio_encoder_output)
# Replace audio token positions with audio embeddings in the LLM input
input_embeddings_for_llm[audio_mask] = audio_features_after_adapter
# Remove audio parts from text_embeddings_pure (delete, not set to zero)
# This ensures returned text_embeddings strictly contains no audio features
if audio_mask.any():
# Process each batch separately since audio positions may differ
batch_size = text_embeddings_pure.shape[0]
text_embeddings_list = []
for i in range(batch_size):
# Get text-only mask for this batch (inverse of audio_mask)
text_mask = ~audio_mask[i] # shape: (seq_len,)
# Extract only text embeddings
text_emb = text_embeddings_pure[i][text_mask] # shape: (text_seq_len, hidden_size)
text_embeddings_list.append(text_emb)
# Pad sequences to the same length for batching
# Use the maximum text sequence length across batches
max_text_len = max(emb.shape[0] for emb in text_embeddings_list) if text_embeddings_list else 0
if max_text_len > 0:
hidden_size = text_embeddings_pure.shape[2]
text_embeddings_pure = torch.zeros(
(batch_size, max_text_len, hidden_size),
dtype=text_embeddings_pure.dtype,
device=text_embeddings_pure.device
)
for i, emb in enumerate(text_embeddings_list):
text_len = emb.shape[0]
text_embeddings_pure[i, :text_len] = emb
else:
# No text embeddings (all are audio tokens)
hidden_size = text_embeddings_pure.shape[2]
text_embeddings_pure = torch.zeros(
(batch_size, 0, hidden_size),
dtype=text_embeddings_pure.dtype,
device=text_embeddings_pure.device
)
# If no audio tokens, text_embeddings_pure remains unchanged
# Forward through LLM
if past_key_values is not None:
# Incremental decoding
llm_outputs = self.llm(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs
)
else:
# First step: use combined embeddings (may include audio features)
llm_outputs = self.llm(
inputs_embeds=input_embeddings_for_llm,
attention_mask=attention_mask,
past_key_values=None,
output_hidden_states=output_hidden_states,
return_dict=True,
**kwargs
)
# Extract LLM hidden states
llm_hidden_states = llm_outputs.hidden_states if output_hidden_states else None
# Separate audio and text hidden states if audio is present
llm_hidden_states_text_only = None
llm_hidden_states_audio_only = None
if output_hidden_states and llm_hidden_states is not None and audio_mask is not None:
# Separate each layer's hidden states into text and audio parts
llm_hidden_states_text_only = tuple()
llm_hidden_states_audio_only = tuple()
batch_size = llm_hidden_states[0].shape[0]
hidden_size = llm_hidden_states[0].shape[2]
for layer_hidden_states in llm_hidden_states:
# layer_hidden_states shape: (batch_size, seq_len, hidden_size)
# audio_mask shape: (batch_size, seq_len)
# Process text-only hidden states: delete audio positions, not set to zero
text_hidden_list = []
audio_hidden_list = []
for i in range(batch_size):
# Get text-only mask for this batch (inverse of audio_mask)
text_mask = ~audio_mask[i] # shape: (seq_len,)
audio_mask_i = audio_mask[i] # shape: (seq_len,)
# Extract only text hidden states (delete audio positions)
text_hidden_i = layer_hidden_states[i][text_mask] # shape: (text_seq_len, hidden_size)
text_hidden_list.append(text_hidden_i)
# Extract only audio hidden states (delete text positions)
audio_hidden_i = layer_hidden_states[i][audio_mask_i] # shape: (audio_seq_len, hidden_size)
audio_hidden_list.append(audio_hidden_i)
# Pad sequences to the same length for batching
# Use the maximum text sequence length across batches
max_text_len = max(emb.shape[0] for emb in text_hidden_list) if text_hidden_list else 0
max_audio_len = max(emb.shape[0] for emb in audio_hidden_list) if audio_hidden_list else 0
if max_text_len > 0:
text_hidden = torch.zeros(
(batch_size, max_text_len, hidden_size),
dtype=layer_hidden_states.dtype,
device=layer_hidden_states.device
)
for i, emb in enumerate(text_hidden_list):
text_len = emb.shape[0]
text_hidden[i, :text_len] = emb
else:
# No text hidden states (all are audio tokens)
text_hidden = torch.zeros(
(batch_size, 0, hidden_size),
dtype=layer_hidden_states.dtype,
device=layer_hidden_states.device
)
if max_audio_len > 0:
audio_hidden = torch.zeros(
(batch_size, max_audio_len, hidden_size),
dtype=layer_hidden_states.dtype,
device=layer_hidden_states.device
)
for i, emb in enumerate(audio_hidden_list):
audio_len = emb.shape[0]
audio_hidden[i, :audio_len] = emb
else:
# No audio hidden states (all are text tokens)
audio_hidden = torch.zeros(
(batch_size, 0, hidden_size),
dtype=layer_hidden_states.dtype,
device=layer_hidden_states.device
)
llm_hidden_states_text_only += (text_hidden,)
llm_hidden_states_audio_only += (audio_hidden,)
return {
"audio_encoder_output": audio_encoder_output,
"audio_features_after_adapter": audio_features_after_adapter,
"text_embeddings": text_embeddings_pure, # Return text embeddings with audio parts removed
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"llm_hidden_states": llm_hidden_states,
"llm_hidden_states_text_only": llm_hidden_states_text_only,
"llm_hidden_states_audio_only": llm_hidden_states_audio_only,
"llm_outputs": llm_outputs,
}
def generate(
self,
input_ids,
attention_mask=None,
mels=None,
mel_masks=None,
generation_config=None,
**generate_kwargs
):
"""
New implementation of the generate method to support audio inputs.
This method will:
1. Handle the initial processing of audio inputs;
2. Call the underlying LLM's generate method with the appropriate embeddings;
3. The incremental decoding will be handled by the LLM's generate method using past_key_values.
"""
# Process audio inputs and get combined embeddings for the initial step
input_embeddings = self.embedding_with_audio_tokens(input_ids, mels, mel_masks)
# Call the underlying LLM's generate method with inputs_embeds instead of input_ids
# The LLM's generate method will handle the generation loop.
# During incremental decoding, it will use past_key_values to avoid re-processing audio inputs.
outputs = self.llm.generate(
inputs_embeds=input_embeddings,
attention_mask=attention_mask,
generation_config=generation_config,
use_cache=True,
**generate_kwargs
)
return outputs
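    # Hypothetical usage sketch (variable names are assumptions): `mels` is
    # (batch, num_mel_bins, frames), `mel_masks` is (batch, frames), and input_ids holds
    # one `audio_token` per post-CNN audio frame. Because generation starts from
    # `inputs_embeds`, the returned ids contain only the newly generated tokens.
    # >>> out_ids = model.generate(input_ids, attention_mask=attention_mask,
    # ...                          mels=mels, mel_masks=mel_masks, max_new_tokens=64)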
class UASAudioEncoderOnly(PreTrainedModel):
"""
UASAudio encoder-only model that contains only the audio encoder and adapter.
Input: audio features
Output: features processed by encoder and adapter
"""
config_class = UASAudioEncoderOnlyConfig
main_input_name = "input_features"
input_modalities = "audio"
def __init__(self, config: UASAudioEncoderOnlyConfig):
super().__init__(config)
        dtype = getattr(config, "dtype", torch.bfloat16)
        if isinstance(dtype, str):
            dtype = getattr(torch, dtype)
self.bf16 = dtype == torch.bfloat16
self.audio_encoder = UASAudioEncoder(config.audio_encoder_config)
d_model = config.audio_encoder_config.d_model
self.adapter = Adapter(
d_model,
config.hidden_size,
)
if self.bf16:
self.audio_encoder = self.audio_encoder.bfloat16()
self.adapter = self.adapter.bfloat16()
self.post_init()
def forward(
self,
input_features: torch.Tensor,
feature_lens: Optional[torch.Tensor] = None,
**kwargs
):
"""
Forward pass through audio encoder and adapter.
Args:
            input_features: Audio features tensor of shape (num_mel_bins, total_frames), with sequences packed along time
feature_lens: Optional tensor of shape (batch_size,) indicating the length of each sequence
**kwargs: Additional arguments passed to audio encoder
Returns:
torch.Tensor: Features processed by encoder and adapter
"""
# Encode audio features
audio_features_encoded = self.audio_encoder(
input_features,
feature_lens=feature_lens,
**kwargs
)
# Handle tuple output from encoder (backward compatibility)
if isinstance(audio_features_encoded, tuple):
audio_features_encoded = audio_features_encoded[0]
elif hasattr(audio_features_encoded, "last_hidden_state"):
audio_features_encoded = audio_features_encoded.last_hidden_state
# Apply adapter (projector)
output_features = self.adapter(audio_features_encoded)
return output_features
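# Hypothetical usage sketch (128 mel bins assumed), mirroring UASAudioEncoder's packed
# input layout:
# >>> model = UASAudioEncoderOnly(config)
# >>> feats = torch.randn(128, 300)                # (num_mel_bins, total_frames)
# >>> out = model(feats, feature_lens=torch.tensor([300]))
# >>> out.shape[-1] == config.hidden_size
# True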