Output attentions

#3
by urroxyz - opened

I've discovered that it is possible to extract word-level timestamps from Voxtral's PyTorch model.

Unfortunately, the ONNX conversion has been exported with output_attentions=False (something that should always be True in my opinion).

@Xenova, is there any chance you could upload an attentive conversion?

In the meantime, I will attempt my own.
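
For reference, the gist on the PyTorch side is just to request the decoder's attention maps and read off how each generated text token attends to the audio positions. A rough sketch, where model and inputs stand in for a transformers-style Voxtral model and its prepared inputs (placeholders, not the exact Voxtral API):

import torch

def attention_maps_for_alignment(model, inputs):
    # One (batch, heads, seq, seq) tensor per decoder layer; slicing the rows of
    # the generated text tokens against the columns of the audio positions gives
    # an alignment matrix that can be turned into word-level timestamps (e.g.
    # via DTW). This is exactly what an export with output_attentions=False hides.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    return outputs.attentions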

Hi, I would like to ask how you enable word-level timestamps in Voxtral.

Of course! I've just uploaded a version of the model that makes this possible.

I re-exported the ONNX conversion with both attentions and positions exposed. It's available here. You'll need to download both decoder files from this repo.

Once you've done that, you can run the following Python cell for CPU inference (as long as you also have an audio.wav in the same directory, plus ipython installed):

import os
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
import soundfile as sf
import librosa
import logging
import matplotlib.pyplot as plt
from IPython.display import display, Audio
import torch  # needed for attention post-processing (stacking, averaging, softmax)

# --- Helper function to create the 4D mask ---
def _create_4d_causal_attention_mask(input_shape, past_sequence_length, dtype=np.float16):
    batch_size, sequence_length = input_shape
    total_sequence_length = past_sequence_length + sequence_length
    
    mask = np.tril(np.ones((total_sequence_length, total_sequence_length), dtype=np.bool_))
    mask = mask[past_sequence_length:, :]
    
    causal_mask = np.zeros((batch_size, 1, sequence_length, total_sequence_length), dtype=dtype)
    causal_mask[:, :, :, :] = np.where(
        mask[None, None, :, :], 0.0, np.finfo(dtype).min
    )
    return causal_mask
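
# Example: _create_4d_causal_attention_mask((1, 3), 2) returns a (1, 1, 3, 5)
# float16 array that is 0.0 where attention is allowed and np.finfo(np.float16).min
# where it is masked: each of the 3 new positions sees both cached positions,
# itself, and any earlier new positions.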

# ==============================================================================
# PART 1: SETUP AND MODEL LOADING
# ==============================================================================

# --- Configuration ---
repo_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX"
audio_file_path = "audio.wav"
custom_decoder_path = "decoder_with_attentions.onnx"
max_generation_tokens = 999
eos_token_id = 2

# --- Download, Tokenizer, Sessions ---
print(f"Downloading base model files from {repo_id}...")
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    allow_patterns=["onnx/audio_encoder_q4.onnx", "onnx/embed_tokens_q4.onnx", "tokenizer.json"],
)
onnx_dir = os.path.join(local_dir, "onnx")
print("\nLoading tokenizer...")
tok = Tokenizer.from_file(os.path.join(local_dir, "tokenizer.json"))
bos_id, inst_id, baud_id, aud_id, einst_id = 1, 3, 25, 24, 4  # special token ids: BOS, [INST], begin-audio, audio placeholder, [/INST]
print("\nInitializing ONNX Runtime sessions...")
ae_path = os.path.join(onnx_dir, "audio_encoder_q4.onnx")
embed_path = os.path.join(onnx_dir, "embed_tokens_q4.onnx")
if not os.path.exists(custom_decoder_path):
    raise FileNotFoundError(f"Custom ONNX decoder not found at '{custom_decoder_path}'.")
sess_opts = ort.SessionOptions()
session_providers = ["CPUExecutionProvider"]
ae_sess = ort.InferenceSession(ae_path, sess_options=sess_opts, providers=session_providers)
embed_sess = ort.InferenceSession(embed_path, sess_options=sess_opts, providers=session_providers)
dec_sess = ort.InferenceSession(custom_decoder_path, sess_options=sess_opts, providers=session_providers)
print("Sessions initialized.")
num_decoder_layers = sum(1 for i in dec_sess.get_inputs() if i.name.endswith(".key"))
print(f"Detected {num_decoder_layers} decoder layers.")

# --- Audio Pre-processing ---
def extract_mel_features_for_chunk(audio_chunk, sampling_rate=16000, n_fft=400, hop_length=160, n_mels=128, target_length=3000):
    target_samples = sampling_rate * 30
    audio_chunk = librosa.util.fix_length(audio_chunk, size=target_samples)
    mel_spec = librosa.feature.melspectrogram(y=audio_chunk, sr=sampling_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    log_spec = np.log10(np.maximum(mel_spec, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec[:, :target_length].astype(np.float32)

def process_long_audio(audio_path, session, sampling_rate=16000):
    y, sr = sf.read(audio_path)
    if y.ndim > 1: y = y.mean(axis=1)
    if sr != sampling_rate: y = librosa.resample(y, orig_sr=sr, target_sr=sampling_rate)
    chunk_samples = 30 * sampling_rate  # 30-second chunks
    num_chunks = int(np.ceil(len(y) / chunk_samples))
    all_audio_embeds = []
    print(f"Processing in {num_chunks} chunk(s)...")
    for i in range(num_chunks):
        chunk = y[i * chunk_samples:(i + 1) * chunk_samples]
        mel_features = extract_mel_features_for_chunk(chunk, sampling_rate)
        all_audio_embeds.append(session.run(None, {session.get_inputs()[0].name: mel_features[None, :]})[0])
    return np.concatenate(all_audio_embeds, axis=0)

# ==============================================================================
# PART 2: GENERATION PASS
# ==============================================================================
print("\n--- Starting Transcription Generation Pass ---")
if not os.path.exists(audio_file_path): raise FileNotFoundError(f"Audio file '{audio_file_path}' not found.")

audio_embeds_raw = process_long_audio(audio_file_path, ae_sess)
batch_size = 1
audio_output_frames = audio_embeds_raw.shape[0] // batch_size
audio_embeds = audio_embeds_raw.reshape(batch_size, audio_output_frames, -1)
text_instruction_ids = tok.encode("Transcribe the audio.", add_special_tokens=False).ids
prompt_tokens = ([bos_id, inst_id, baud_id] + [aud_id] * audio_output_frames + text_instruction_ids + [einst_id])
initial_sequence_length = len(prompt_tokens)

# Create initial embeddings
prompt_ids = np.array([prompt_tokens], dtype=np.int64)
inputs_embeds = embed_sess.run(None, {"input_ids": prompt_ids})[0]
inputs_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
inputs_embeds = inputs_embeds.astype(np.float16)

# Generation loop
generated_ids = []
past_key_values = None
current_past_len = 0
print("\nGenerating text...")
for i in range(max_generation_tokens):
    dec_inputs = {}
    if i == 0:
        dec_inputs["inputs_embeds"] = inputs_embeds
        attention_mask = _create_4d_causal_attention_mask((batch_size, initial_sequence_length), 0)
    else:
        last_token_id = np.array([[generated_ids[-1]]], dtype=np.int64)
        dec_inputs["inputs_embeds"] = embed_sess.run(None, {"input_ids": last_token_id})[0].astype(np.float16)
        attention_mask = _create_4d_causal_attention_mask((batch_size, 1), current_past_len)

    dec_inputs["attention_mask"] = attention_mask
    if past_key_values:
        for l in range(num_decoder_layers):
            dec_inputs[f"past_key_values.{l}.key"] = past_key_values[l*2]
            dec_inputs[f"past_key_values.{l}.value"] = past_key_values[l*2+1]
    else: # First pass dummy past
        for l in range(num_decoder_layers):
            dec_inputs[f"past_key_values.{l}.key"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
            dec_inputs[f"past_key_values.{l}.value"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)

    outputs = dec_sess.run(None, dec_inputs)
    logits, past_key_values = outputs[0], outputs[1:1+num_decoder_layers*2]
    
    next_token_id = np.argmax(logits[0, -1, :])
    if next_token_id == eos_token_id: break
    generated_ids.append(next_token_id)
    print(tok.decode([next_token_id]), end="", flush=True)
    current_past_len += dec_inputs["inputs_embeds"].shape[1]
print("\n\n--- Transcription Complete ---")

# ==============================================================================
# PART 3: ALIGNMENT PASS AND TIMESTAMP EXTRACTION
# ==============================================================================
print("\n--- Starting Word-Level Timestamp Alignment ---")

# --- 1. Prepare inputs for the single alignment pass ---
print("Preparing inputs for alignment pass...")
full_sequence_ids = np.array([prompt_tokens + generated_ids], dtype=np.int64)
full_embeds = embed_sess.run(None, {"input_ids": full_sequence_ids})[0]
full_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
full_embeds = full_embeds.astype(np.float16)

alignment_inputs = {
    "inputs_embeds": full_embeds,
    "attention_mask": _create_4d_causal_attention_mask(full_embeds.shape[:2], 0)
}
for l in range(num_decoder_layers):
    alignment_inputs[f"past_key_values.{l}.key"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
    alignment_inputs[f"past_key_values.{l}.value"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)

# --- 2. Run the alignment pass ---
print("Running forward pass to extract attention weights...")
alignment_outputs = dec_sess.run(None, alignment_inputs)
attentions = [torch.from_numpy(attn) for attn in alignment_outputs[1+num_decoder_layers*2:]]
print("Attention weights extracted.")

# --- 3. Process attentions ---
print("Processing attention weights for alignment...")
text_start_idx = len(prompt_tokens)
audio_end_idx = 3 + audio_output_frames
start_layer, end_layer = 10, 20
layer_attentions = []
for i in range(start_layer, end_layer):
    layer_attn = attentions[i][0] 
    layer_attn_avg_heads = layer_attn.mean(dim=0)
    relevant_attns = layer_attn_avg_heads[text_start_idx:, 3:audio_end_idx]
    if relevant_attns.numel() > 0:
        layer_attentions.append(relevant_attns)

if not layer_attentions:
    raise ValueError("Could not extract any valid attention weights. The generated text might be empty.")

avg_attentions = torch.stack(layer_attentions).mean(dim=0)
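# Sharpen each text token's distribution over the audio frames with a
# low-temperature softmax so the DTW step below follows the dominant
# attention ridge rather than diffuse attention mass.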
temperature = 0.1
weights = torch.nn.functional.softmax(avg_attentions / temperature, dim=1).cpu().numpy()

plt.figure(figsize=(10, 10))
plt.imshow(weights, aspect="auto", origin="lower", cmap="viridis")
plt.xlabel("Audio Frames")
plt.ylabel("Generated Text Tokens")
plt.title("Audio-to-Text Alignment Matrix (Sharpened)")
plt.colorbar()
plt.savefig("alignment_matrix.png")
print("Saved alignment matrix visualization to alignment_matrix.png")

# --- 4. DTW and Timestamp calculation ---
print("Performing DTW and mapping to timestamps...")
cost_matrix = -weights.T
D, wp = librosa.sequence.dtw(C=cost_matrix.astype(np.float32), backtrack=True)
wp = np.flip(wp, axis=0)

# *** THE FIX IS HERE ***
token_to_frame_map = {}
# Create a map of the first frame seen for each token index
for frame_idx, token_idx in wp:
    if token_idx not in token_to_frame_map:
        token_to_frame_map[token_idx] = frame_idx

word_groups = []
current_word_tokens = []
if generated_ids:
    for token_id in generated_ids:
        if tok.decode([token_id]).startswith(" ") and current_word_tokens:
            word_groups.append(current_word_tokens)
            current_word_tokens = []
        current_word_tokens.append(token_id)
    if current_word_tokens: word_groups.append(current_word_tokens)

# Note: 30.0 s assumes the audio fits in a single padded 30 s chunk; for longer
# inputs, scale this by the number of 30 s chunks (each chunk contributes the
# same number of encoder frames), otherwise timestamps get compressed.
EFFECTIVE_AUDIO_DURATION = 30.0
AUDIO_TIME_PER_FRAME = EFFECTIVE_AUDIO_DURATION / audio_output_frames
results = []
token_idx_counter = 0
# Initialize start time using the very first aligned token
previous_word_end_time = token_to_frame_map.get(0, 0) * AUDIO_TIME_PER_FRAME

for word_group in word_groups:
    word_text = tok.decode(word_group).strip()
    if not word_text: continue
    
    start_time = previous_word_end_time
    last_token_in_word_idx = token_idx_counter + len(word_group) - 1
    end_frame = token_to_frame_map.get(last_token_in_word_idx, 0)
    end_time = max(start_time, end_frame * AUDIO_TIME_PER_FRAME)
    results.append({"word": word_text, "start": start_time, "end": end_time})
    previous_word_end_time = end_time
    token_idx_counter += len(word_group)

print("\n--- Transcription with Word-Level Timestamps ---")
for res in results: print(f"[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
print("\n--- Timestamp Extraction Complete ---")

# ==============================================================================
# PART 4: VERIFICATION
# ==============================================================================
print("\n--- Verifying Alignment: Playing Audio Snippets ---")
SAMPLING_RATE = 16000
y, sr = librosa.load(audio_file_path, sr=SAMPLING_RATE)
if not results: print("No words were transcribed to verify.")
else:
    for res in results:
        start_sample = int(res['start'] * SAMPLING_RATE)
        end_sample = int(res['end'] * SAMPLING_RATE)
        audio_snippet = y[start_sample:end_sample]
        print(f"\n[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
        if len(audio_snippet) > 0: display(Audio(audio_snippet, rate=SAMPLING_RATE))
        else: print("   (No audio for this segment)")
print("\n--- Verification Complete ---")

Let me know if you have any issues. I'll upload some quantizations soon.

ONNX Community org

That's really great work, @urroxyz! I'd also be interested in seeing how your export compares in performance to the un-timestamped one. Here's a comparison of the graphs:

We should be able to enable this in our optimized export too by adding an intermediate node as output. If you have some time to look into that, I'd greatly appreciate it!

I can definitely look into that.

Right now, I'm just attempting to quantize the model. I couldn't figure out how to do it via the CLI, so I'm tinkering around.

Using the ONNX library is a struggle for me.
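
What I'm tinkering with is roughly onnxruntime's dynamic int8 quantization; just a sketch, untested on the attention-enabled decoder and almost certainly not the pipeline behind the official q4 files (file names are my local ones):

from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic int8 weight quantization of the re-exported decoder.
# use_external_data_format=True keeps the weights in a separate file, since the
# model exceeds the 2 GB protobuf limit.
quantize_dynamic(
    "decoder_with_attentions.onnx",
    "decoder_with_attentions_int8.onnx",
    weight_type=QuantType.QUInt8,
    use_external_data_format=True,
)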

ONNX Community org

I can definitely look into that.

The quickest way (i.e., without re-exporting) would be to use the onnx library directly and modify the necessary nodes. I can try to outline what I'd do, in case it helps:

  1. Following the docs, https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftgroupqueryattention

qk_output : int
Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).

Modify each GroupQueryAttention node to use a value of 2 for qk_output (assuming you want the normalized scores).

  2. You can then use the output_qk value:

output_qk (optional) : T
Values of QK matrix multiplication, either before or after softmax normalization

  3. Add each of these output_qk tensors to the list of model outputs (graph.output); a rough sketch of all three steps is below.
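
Roughly, the graph surgery could look like the sketch below. It's untested and makes a few assumptions: output_qk is the optional fourth output slot of GroupQueryAttention, activations are float16, and the file names are placeholders.

import onnx
from onnx import TensorProto, helper

# Load the merged decoder; external weights are picked up from the same folder.
model = onnx.load("decoder_model_merged_q4.onnx")
graph = model.graph

new_outputs = []
for i, node in enumerate(graph.node):
    if node.op_type != "GroupQueryAttention":
        continue
    # 1. Set qk_output = 2 so the node emits post-softmax QK scores (1 = pre-softmax).
    for attr in list(node.attribute):
        if attr.name == "qk_output":
            node.attribute.remove(attr)
    node.attribute.append(helper.make_attribute("qk_output", 2))
    # 2. Give the optional output_qk slot a name (assumed to be output index 3).
    while len(node.output) < 4:
        node.output.append("")
    qk_name = (node.name or f"gqa_{i}") + "_output_qk"
    node.output[3] = qk_name
    # 3. Expose it as a graph output (float16 assumed; shape left unspecified).
    new_outputs.append(helper.make_tensor_value_info(qk_name, TensorProto.FLOAT16, None))

graph.output.extend(new_outputs)
onnx.save(model, "decoder_with_qk.onnx", save_as_external_data=True)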

@Xenova, I accidentally saved my ONNX conversion without the onnx_data file.

Because of this, no one else will be able to run it. Trying to fix that now. (Edit: Fixed!)

If you could also create your own version (attentions/positions enabled, alongside quants), I would greatly appreciate it, since you're much more experienced with ONNX. Or, if you could check the code in the model README.md to see whether I did it right, that would also help.
