Output attentions
I've discovered that it is possible to extract word-level timestamps from Voxtral's PyTorch model.
Unfortunately, the ONNX conversion has been exported with output_attentions=False (something that should always be True, in my opinion).
@Xenova, is there any chance you could upload a conversion with attentions enabled?
In the meantime, I will attempt my own.
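For anyone curious, the PyTorch side only needs the standard Transformers generate() switches. Here's a minimal, untested sketch of the pattern (the Voxtral-specific model and processor setup is omitted, so treat this as an illustration rather than my exact code):

import torch

def generate_with_attentions(model, inputs, max_new_tokens=512):
    """Generate while keeping the per-step attention maps.

    `model` is any Transformers generation model loaded with
    attn_implementation="eager" (SDPA/flash kernels don't materialize the
    attention probabilities), and `inputs` is whatever its processor returns.
    """
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            output_attentions=True,        # keep the attention probabilities
            return_dict_in_generate=True,  # return a structured output object
        )
    # out.attentions holds one tuple per generated token, each containing
    # per-layer tensors of shape (batch, num_heads, query_len, key_len);
    # slicing the key axis down to the audio positions gives the matrix
    # needed for word-level alignment.
    return out.sequences, out.attentions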
Hi, I would like to ask how you enable word-level timestamps in Voxtral.
Of course! I've just uploaded a version of the model that makes this possible.
I re-exported the ONNX conversion with both attentions and positions exposed. It's available here. You'll need to download both decoder files from this repo.
Once you've done that, you can run the following Python cell for CPU inference (as long as you also have an audio.wav in the same directory, plus ipython installed):
import os
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
import soundfile as sf
import librosa
import logging
import matplotlib.pyplot as plt
from IPython.display import display, Audio
import torch  # Still needed for attention post-processing (stacking, averaging, softmax)
# --- Helper function to create the 4D mask ---
def _create_4d_causal_attention_mask(input_shape, past_sequence_length, dtype=np.float16):
    batch_size, sequence_length = input_shape
    total_sequence_length = past_sequence_length + sequence_length
    mask = np.tril(np.ones((total_sequence_length, total_sequence_length), dtype=np.bool_))
    mask = mask[past_sequence_length:, :]
    causal_mask = np.zeros((batch_size, 1, sequence_length, total_sequence_length), dtype=dtype)
    causal_mask[:, :, :, :] = np.where(
        mask[None, None, :, :], 0.0, np.finfo(dtype).min
    )
    return causal_mask
# ==============================================================================
# PART 1: SETUP AND MODEL LOADING
# ==============================================================================
# --- Configuration ---
repo_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX"
audio_file_path = "audio.wav"
custom_decoder_path = "decoder_with_attentions.onnx"
max_generation_tokens = 999
eos_token_id = 2
# --- Download, Tokenizer, Sessions ---
print(f"Downloading base model files from {repo_id}...")
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    allow_patterns=["onnx/audio_encoder_q4.onnx", "onnx/embed_tokens_q4.onnx", "tokenizer.json"],
)
onnx_dir = os.path.join(local_dir, "onnx")
print("\nLoading tokenizer...")
tok = Tokenizer.from_file(os.path.join(local_dir, "tokenizer.json"))
bos_id, inst_id, baud_id, aud_id, einst_id = 1, 3, 25, 24, 4
print("\nInitializing ONNX Runtime sessions...")
ae_path = os.path.join(onnx_dir, "audio_encoder_q4.onnx")
embed_path = os.path.join(onnx_dir, "embed_tokens_q4.onnx")
if not os.path.exists(custom_decoder_path):
    raise FileNotFoundError(f"Custom ONNX decoder not found at '{custom_decoder_path}'.")
sess_opts = ort.SessionOptions()
session_providers = ["CPUExecutionProvider"]
ae_sess = ort.InferenceSession(ae_path, sess_options=sess_opts, providers=session_providers)
embed_sess = ort.InferenceSession(embed_path, sess_options=sess_opts, providers=session_providers)
dec_sess = ort.InferenceSession(custom_decoder_path, sess_options=sess_opts, providers=session_providers)
print("Sessions initialized.")
num_decoder_layers = sum(1 for i in dec_sess.get_inputs() if i.name.endswith(".key"))
print(f"Detected {num_decoder_layers} decoder layers.")
# --- Audio Pre-processing ---
def extract_mel_features_for_chunk(audio_chunk, sampling_rate=16000, n_fft=400, hop_length=160, n_mels=128, target_length=3000):
    target_samples = sampling_rate * 30
    audio_chunk = librosa.util.fix_length(audio_chunk, size=target_samples)
    mel_spec = librosa.feature.melspectrogram(y=audio_chunk, sr=sampling_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    log_spec = np.log10(np.maximum(mel_spec, 1e-10))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec[:, :target_length].astype(np.float32)
def process_long_audio(audio_path, session, sampling_rate=16000):
    y, sr = sf.read(audio_path)
    if y.ndim > 1: y = y.mean(axis=1)
    if sr != sampling_rate: y = librosa.resample(y, orig_sr=sr, target_sr=sampling_rate)
    chunk_samples = 30 * sampling_rate
    num_chunks = int(np.ceil(len(y) / chunk_samples))
    all_audio_embeds = []
    print(f"Processing in {num_chunks} chunk(s)...")
    for i in range(num_chunks):
        chunk = y[i * chunk_samples:(i + 1) * chunk_samples]
        mel_features = extract_mel_features_for_chunk(chunk, sampling_rate)
        all_audio_embeds.append(session.run(None, {session.get_inputs()[0].name: mel_features[None, :]})[0])
    return np.concatenate(all_audio_embeds, axis=0)
# ==============================================================================
# PART 2: GENERATION PASS
# ==============================================================================
print("\n--- Starting Transcription Generation Pass ---")
if not os.path.exists(audio_file_path): raise FileNotFoundError(f"Audio file '{audio_file_path}' not found.")
audio_embeds_raw = process_long_audio(audio_file_path, ae_sess)
batch_size = 1
audio_output_frames = audio_embeds_raw.shape[0] // batch_size
audio_embeds = audio_embeds_raw.reshape(batch_size, audio_output_frames, -1)
text_instruction_ids = tok.encode("Transcribe the audio.", add_special_tokens=False).ids
prompt_tokens = ([bos_id, inst_id, baud_id] + [aud_id] * audio_output_frames + text_instruction_ids + [einst_id])
initial_sequence_length = len(prompt_tokens)
# Create initial embeddings
prompt_ids = np.array([prompt_tokens], dtype=np.int64)
inputs_embeds = embed_sess.run(None, {"input_ids": prompt_ids})[0]
inputs_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
inputs_embeds = inputs_embeds.astype(np.float16)
# Generation loop
generated_ids = []
past_key_values = None
current_past_len = 0
print("\nGenerating text...")
for i in range(max_generation_tokens):
    dec_inputs = {}
    if i == 0:
        dec_inputs["inputs_embeds"] = inputs_embeds
        attention_mask = _create_4d_causal_attention_mask((batch_size, initial_sequence_length), 0)
    else:
        last_token_id = np.array([[generated_ids[-1]]], dtype=np.int64)
        dec_inputs["inputs_embeds"] = embed_sess.run(None, {"input_ids": last_token_id})[0].astype(np.float16)
        attention_mask = _create_4d_causal_attention_mask((batch_size, 1), current_past_len)
    dec_inputs["attention_mask"] = attention_mask
    if past_key_values:
        for l in range(num_decoder_layers):
            dec_inputs[f"past_key_values.{l}.key"] = past_key_values[l*2]
            dec_inputs[f"past_key_values.{l}.value"] = past_key_values[l*2+1]
    else:  # First pass dummy past
        for l in range(num_decoder_layers):
            dec_inputs[f"past_key_values.{l}.key"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
            dec_inputs[f"past_key_values.{l}.value"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
    outputs = dec_sess.run(None, dec_inputs)
    logits, past_key_values = outputs[0], outputs[1:1+num_decoder_layers*2]
    next_token_id = np.argmax(logits[0, -1, :])
    if next_token_id == eos_token_id: break
    generated_ids.append(next_token_id)
    print(tok.decode([next_token_id]), end="", flush=True)
    current_past_len += dec_inputs["inputs_embeds"].shape[1]
print("\n\n--- Transcription Complete ---")
# ==============================================================================
# PART 3: ALIGNMENT PASS AND TIMESTAMP EXTRACTION
# ==============================================================================
print("\n--- Starting Word-Level Timestamp Alignment ---")
# --- 1. Prepare inputs for the single alignment pass ---
print("Preparing inputs for alignment pass...")
full_sequence_ids = np.array([prompt_tokens + generated_ids], dtype=np.int64)
full_embeds = embed_sess.run(None, {"input_ids": full_sequence_ids})[0]
full_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
full_embeds = full_embeds.astype(np.float16)
alignment_inputs = {
"inputs_embeds": full_embeds,
"attention_mask": _create_4d_causal_attention_mask(full_embeds.shape[:2], 0)
}
for l in range(num_decoder_layers):
alignment_inputs[f"past_key_values.{l}.key"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
alignment_inputs[f"past_key_values.{l}.value"] = np.zeros((batch_size, 8, 0, 128), dtype=np.float16)
# --- 2. Run the alignment pass ---
print("Running forward pass to extract attention weights...")
alignment_outputs = dec_sess.run(None, alignment_inputs)
attentions = [torch.from_numpy(attn) for attn in alignment_outputs[1+num_decoder_layers*2:]]
print("Attention weights extracted.")
# --- 3. Process attentions ---
print("Processing attention weights for alignment...")
text_start_idx = len(prompt_tokens)
audio_end_idx = 3 + audio_output_frames
start_layer, end_layer = 10, 20
layer_attentions = []
for i in range(start_layer, end_layer):
    layer_attn = attentions[i][0]
    layer_attn_avg_heads = layer_attn.mean(dim=0)
    relevant_attns = layer_attn_avg_heads[text_start_idx:, 3:audio_end_idx]
    if relevant_attns.numel() > 0:
        layer_attentions.append(relevant_attns)
if not layer_attentions:
    raise ValueError("Could not extract any valid attention weights. The generated text might be empty.")
avg_attentions = torch.stack(layer_attentions).mean(dim=0)
temperature = 0.1
weights = torch.nn.functional.softmax(avg_attentions / temperature, dim=1).cpu().numpy()
plt.figure(figsize=(10, 10))
plt.imshow(weights, aspect="auto", origin="lower", cmap="viridis")
plt.xlabel("Audio Frames")
plt.ylabel("Generated Text Tokens")
plt.title("Audio-to-Text Alignment Matrix (Sharpened)")
plt.colorbar()
plt.savefig("alignment_matrix.png")
print("Saved alignment matrix visualization to alignment_matrix.png")
# --- 4. DTW and Timestamp calculation ---
print("Performing DTW and mapping to timestamps...")
cost_matrix = -weights.T
D, wp = librosa.sequence.dtw(C=cost_matrix.astype(np.float32), backtrack=True)
wp = np.flip(wp, axis=0)
# *** THE FIX IS HERE ***
token_to_frame_map = {}
# Create a map of the first frame seen for each token index
for frame_idx, token_idx in wp:
    if token_idx not in token_to_frame_map:
        token_to_frame_map[token_idx] = frame_idx
word_groups = []
current_word_tokens = []
if generated_ids:
    for token_id in generated_ids:
        if tok.decode([token_id]).startswith(" ") and current_word_tokens:
            word_groups.append(current_word_tokens)
            current_word_tokens = []
        current_word_tokens.append(token_id)
if current_word_tokens: word_groups.append(current_word_tokens)
EFFECTIVE_AUDIO_DURATION = 30.0
AUDIO_TIME_PER_FRAME = EFFECTIVE_AUDIO_DURATION / audio_output_frames
results = []
token_idx_counter = 0
# Initialize start time using the very first aligned token
previous_word_end_time = token_to_frame_map.get(0, 0) * AUDIO_TIME_PER_FRAME
for word_group in word_groups:
    word_text = tok.decode(word_group).strip()
    if not word_text:
        token_idx_counter += len(word_group)  # keep token indices aligned even when skipping empty groups
        continue
    start_time = previous_word_end_time
    last_token_in_word_idx = token_idx_counter + len(word_group) - 1
    end_frame = token_to_frame_map.get(last_token_in_word_idx, 0)
    end_time = max(start_time, end_frame * AUDIO_TIME_PER_FRAME)
    results.append({"word": word_text, "start": start_time, "end": end_time})
    previous_word_end_time = end_time
    token_idx_counter += len(word_group)
print("\n--- Transcription with Word-Level Timestamps ---")
for res in results: print(f"[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
print("\n--- Timestamp Extraction Complete ---")
# ==============================================================================
# PART 4: VERIFICATION
# ==============================================================================
print("\n--- Verifying Alignment: Playing Audio Snippets ---")
SAMPLING_RATE = 16000
y, sr = librosa.load(audio_file_path, sr=SAMPLING_RATE)
if not results: print("No words were transcribed to verify.")
else:
    for res in results:
        start_sample = int(res['start'] * SAMPLING_RATE)
        end_sample = int(res['end'] * SAMPLING_RATE)
        audio_snippet = y[start_sample:end_sample]
        print(f"\n[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
        if len(audio_snippet) > 0: display(Audio(audio_snippet, rate=SAMPLING_RATE))
        else: print(" (No audio for this segment)")
print("\n--- Verification Complete ---")
Let me know if you have any issues. I'll upload some quantizations soon.
That's really great work, @urroxyz! I'd also be interested in seeing how your export compares in performance to the un-timestamped one. Here's a comparison of the graphs:
We should be able to enable this in our optimized export too by adding an intermediate node as output. If you have some time to look into that, I'd greatly appreciate it!
I can definitely look into that.
Right now, I'm just attempting to quantize the model. I couldn't figure out how to do it via the CLI, so I'm tinkering around.
Using the ONNX library is a struggle for me.
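For reference, the route I've been poking at is onnxruntime's dynamic quantization. A minimal sketch (file names are placeholders, and I haven't confirmed how well this behaves on this particular decoder):

from onnxruntime.quantization import QuantType, quantize_dynamic

# Weight-only int8 quantization of the re-exported decoder (paths are placeholders).
quantize_dynamic(
    model_input="decoder_with_attentions.onnx",
    model_output="decoder_with_attentions_int8.onnx",
    weight_type=QuantType.QUInt8,
    use_external_data_format=True,  # needed if the weights live in an onnx_data file
)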
> I can definitely look into that.
The quickest way (i.e., without re-exporting) would be to use the onnx library directly and modify the necessary nodes. I can try to outline what I'd do, if it helps (a rough code sketch follows the list):
- Following the docs (https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftgroupqueryattention):

      qk_output : int
      Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).

  Modify each GroupQueryAttention node to use a value of 2 (assuming you want normalized scores) for qk_output.
- You can then use the output_qk value:

      output_qk (optional) : T
      Values of QK matrix multiplication, either before or after softmax normalization

- Add each of these output nodes to the list of model outputs, graph.outputs.
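Roughly, those three steps look like this with the onnx library (an untested sketch; the file name and the FLOAT16 output type are assumptions you'd adjust for your export):

import onnx
from onnx import TensorProto, helper

model = onnx.load("decoder_with_attentions.onnx")  # your decoder export
graph = model.graph

for i, node in enumerate(graph.node):
    if node.op_type != "GroupQueryAttention":
        continue
    # 1) Ask the op for QK scores after softmax (qk_output = 2).
    attr = next((a for a in node.attribute if a.name == "qk_output"), None)
    if attr is not None:
        attr.i = 2
    else:
        node.attribute.append(helper.make_attribute("qk_output", 2))
    # 2) output_qk is the node's optional 4th output; name it if it's empty.
    while len(node.output) < 4:
        node.output.append("")
    if not node.output[3]:
        node.output[3] = f"output_qk.{i}"
    # 3) Expose it as a graph output so onnxruntime returns it.
    #    FLOAT16 assumes an fp16 decoder; use FLOAT for fp32.
    graph.output.append(
        helper.make_tensor_value_info(node.output[3], TensorProto.FLOAT16, None)
    )

# For a >2GB model, keep the weights in an external data file when saving.
onnx.save(model, "decoder_with_qk.onnx", save_as_external_data=True)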
@Xenova, I accidentally saved my ONNX conversion without the onnx_data file.
Because of this, no one else will be able to run it. Trying to fix that now. (Edit: Fixed!)
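For anyone else who runs into this: re-saving with the weights pushed into an external data file looks roughly like the following sketch (file names mirror the ones above, and the original weights must of course still be available):

import onnx

# Write the graph back with the tensors stored in a companion .onnx_data file,
# so the .onnx file itself stays small and the bulk goes into the data file.
model = onnx.load("decoder_with_attentions.onnx")
onnx.save(
    model,
    "decoder_with_attentions_fixed.onnx",
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="decoder_with_attentions_fixed.onnx_data",
)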
If you could also create your own version (attentions/positions enabled, alongside quants), considering you're much more experienced with ONNX, I would greatly appreciate it. Or if you could check the code in the model README.md to see whether I did it right.