Find more info here.

How to create this conversion

Use the script below to export the Voxtral text decoder to ONNX with its per-layer attention weights and KV-cache (present key/value) tensors exposed as outputs.

import torch
from torch import nn
from transformers import VoxtralForConditionalGeneration
from transformers.cache_utils import DynamicCache
import os
import onnx

# Load the full Voxtral checkpoint. Eager attention is requested because the
# export below asks the decoder for output_attentions=True.
model_id = "mistralai/Voxtral-Mini-3B-2507"
has_cuda = torch.cuda.is_available()
device = "cuda" if has_cuda else "cpu"
# Half precision on GPU, full precision on CPU.
torch_dtype = torch.float16 if has_cuda else torch.float32

load_kwargs = {
    "torch_dtype": torch_dtype,
    "low_cpu_mem_usage": True,
    "use_safetensors": True,
    "attn_implementation": "eager",
}
model = VoxtralForConditionalGeneration.from_pretrained(model_id, **load_kwargs)
model.to(device)
model.eval()

class DecoderONNXWrapper(nn.Module):
    """Flatten the Voxtral text decoder's inputs/outputs for ONNX export.

    The past key/value cache is accepted as ``2 * num_hidden_layers``
    positional tensors, interleaved as ``key0, value0, key1, value1, ...``.
    The forward pass returns a flat tuple
    ``(logits, present_key0, present_value0, ..., attention0, attention1, ...)``.
    """

    def __init__(self, language_model):
        super().__init__()
        self.language_model = language_model

    @staticmethod
    def _iter_key_values(cache):
        """Yield one ``(key, value)`` pair per layer from a transformers cache.

        ``to_legacy_cache()`` is preferred. ``layers`` and
        ``key_cache``/``value_cache`` are fallbacks, because different
        transformers releases expose different attributes on Cache objects.
        """
        if hasattr(cache, "to_legacy_cache"):
            yield from cache.to_legacy_cache()
        elif hasattr(cache, "layers"):
            for layer in cache.layers:
                yield layer.keys, layer.values
        else:
            yield from zip(cache.key_cache, cache.value_cache)

    def forward(self, inputs_embeds, attention_mask, *past_key_value_tensors):
        """Run the decoder once on ``inputs_embeds`` with the given cache.

        Args:
            inputs_embeds: Embeddings of the new tokens.
            attention_mask: Mask forwarded unchanged to the language model
                (the export passes a 4-D mask).
            *past_key_value_tensors: Interleaved per-layer past keys/values;
                exactly ``2 * num_hidden_layers`` tensors are required.

        Returns:
            Tuple of logits, the updated keys/values (interleaved per layer),
            then one attention tensor per layer.

        Raises:
            ValueError: If the number of past tensors does not match the
                decoder's layer count.
        """
        num_layers = self.language_model.config.num_hidden_layers
        if len(past_key_value_tensors) != 2 * num_layers:
            raise ValueError(
                f"Expected {2 * num_layers} past key/value tensors, "
                f"got {len(past_key_value_tensors)}."
            )
        legacy_past = tuple(
            (past_key_value_tensors[2 * i], past_key_value_tensors[2 * i + 1])
            for i in range(num_layers)
        )
        past_key_values_cache = DynamicCache.from_legacy_cache(past_key_values=legacy_past)

        outputs = self.language_model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values_cache,
            output_attentions=True,
            use_cache=True,
        )

        flat_outputs = [outputs.logits]
        for key, value in self._iter_key_values(outputs.past_key_values):
            flat_outputs.extend([key, value])
        flat_outputs.extend(outputs.attentions)
        return tuple(flat_outputs)

# Dummy inputs used only to trace the decoder. seq_len and past_seq_len get
# different values so the new-token axis and the cached-token axis can be
# told apart while tracing.
batch_size = 1
seq_len = 128
past_seq_len = 100
text_config = model.config.text_config
num_layers = text_config.num_hidden_layers
hidden_size = text_config.hidden_size
head_dim = text_config.head_dim
num_kv_heads = text_config.num_key_value_heads

inputs_embeds = torch.randn((batch_size, seq_len, hidden_size), dtype=torch_dtype, device=device)
# 4-D mask: (batch, 1, query length, past + query length).
attention_mask_4d = torch.ones((batch_size, 1, seq_len, past_seq_len + seq_len), dtype=torch_dtype, device=device)
# Keys and values for every layer, interleaved: key0, value0, key1, value1, ...
past_key_value_flat_tuple = tuple(
    torch.randn((batch_size, num_kv_heads, past_seq_len, head_dim), dtype=torch_dtype, device=device)
    for _ in range(num_layers * 2)
)
dummy_inputs = (inputs_embeds, attention_mask_4d) + past_key_value_flat_tuple

output_path = "decoder_model_attentive_unpacked.onnx"
input_names = ["inputs_embeds", "attention_mask"] + [f"past_key_values.{i}.{kv}" for i in range(num_layers) for kv in ["key", "value"]]
output_names = ["logits"] + [f"present.{i}.{kv}" for i in range(num_layers) for kv in ["key", "value"]] + [f"attention.{i}" for i in range(num_layers)]

dynamic_axes = {
    "inputs_embeds": {1: "sequence_length"},
    "attention_mask": {2: "sequence_length", 3: "total_sequence_length"},
    # Logits carry one row per input position. Without this entry the exported
    # graph declares that axis as the fixed dummy length (128), which does not
    # match single-token decoding steps.
    "logits": {1: "sequence_length"},
}
for name in input_names + output_names:
    if "key" in name or "value" in name:
        # Past inputs hold only cached positions; present outputs hold cached + new.
        dynamic_axes[name] = {2: "past_sequence_length"} if "past" in name else {2: "total_sequence_length"}
    elif name.startswith("attention."):
        # Per-layer attention outputs: axis 2 = query length, axis 3 = key length.
        dynamic_axes[name] = {2: "sequence_length", 3: "total_sequence_length"}

# Wrap only the text decoder (model.language_model) for export.
wrapped_model = DecoderONNXWrapper(model.language_model).eval()

export_options = {
    "input_names": input_names,
    "output_names": output_names,
    "dynamic_axes": dynamic_axes,
    "opset_version": 17,
}
# Tracing only needs the forward pass; no autograd graph is required.
with torch.no_grad():
    torch.onnx.export(wrapped_model, dummy_inputs, output_path, **export_options)

import onnx

# Re-save the exported decoder so that all of its initializers are stored in
# one external-data file next to the final .onnx graph.
final_model_path = "decoder_model_attentive.onnx"
data_file_location = "decoder_model_attentive.onnx_data"

onnx_model = onnx.load(output_path, load_external_data=True)
onnx.save_model(
    onnx_model,
    final_model_path,
    location=data_file_location,
    all_tensors_to_one_file=True,
    save_as_external_data=True,
)

Clean up the working directory (remove the intermediate files left behind by the export):

# Delete intermediate files in the working directory whose names start with
# "language_" or "onnx_" (presumably the per-tensor external-data files written
# by the export). Directories are skipped; os.remove cannot delete them.
work_dir = "."
for fname in os.listdir(work_dir):
    fpath = os.path.join(work_dir, fname)
    if fname.startswith(("language_", "onnx_")) and os.path.isfile(fpath):
        os.remove(fpath)

The script below quantizes the exported ONNX decoder to 4-bit weights (weight-only, RTN). The accompanying `.onnx_data` file must be present next to the `.onnx` file.

import os
import sys
from neural_compressor import PostTrainingQuantConfig, quantization
from neural_compressor.utils.constant import FP32
import onnx
import logging
logging.basicConfig(level=logging.INFO)

# Input: FP32 model written by the export script. Output: 4-bit weight-only model.
model_dir = "."
model_fp32 = 'decoder_model_attentive.onnx'
model_quantized = 'decoder_model_attentive_q4_weight_only_inc.onnx'

input_model_path = os.path.join(model_dir, model_fp32)
output_model_path = os.path.join(model_dir, model_quantized)

# onnx.checker.check_model() returns None and reports problems by raising, so
# only exceptions indicate failure. The path (not a loaded ModelProto) is passed
# so the checker can handle models whose external data exceeds 2 GB.
try:
    onnx.checker.check_model(input_model_path)
except onnx.checker.ValidationError as e:
    print(f"Error: Original model '{input_model_path}' is not a valid ONNX model: {e}")
    sys.exit(1)
except Exception as e:
    print(f"Failed to load or check original model '{input_model_path}': {e}")
    print("Please ensure the original model file exists and is not corrupted.")
    sys.exit(1)
print(f"Original model '{input_model_path}' is valid.")

# 4-bit, asymmetric, round-to-nearest (RTN) weight-only quantization applied
# to every op type (".*"), with weights grouped in blocks of 32 (group_size).
config = PostTrainingQuantConfig(
    approach="weight_only",
    op_type_dict={
        ".*": {
            "weight": {
                "bits": 4,
                "algorithm": ["RTN"],
                "scheme": ["asym"],
                "group_size": 32,
            }
        }
    },
)

print(f"\nAttempting to quantize '{model_fp32}' to 4-bit weight-only using Neural Compressor...")

try:
    q_model = quantization.fit(
        input_model_path,
        config,
    )

    q_model.save(output_model_path)
    print(f"Model successfully quantized and saved to {output_model_path}")

except Exception as e:
    print(f"Error during Neural Compressor weight-only quantization: {e}")
    print("Please ensure Neural Compressor is installed (`pip install neural_compressor`)")
    print("and that your ONNX Runtime version is compatible.")

Inference and extract word-level timestamps

  1. Download BOTH decoder files from this repo (*.onnx AND *.onnx_data)
  2. Install the requirements: numpy, onnxruntime, huggingface_hub, tokenizers, soundfile, librosa, matplotlib, torch, and ipython
  3. Download audio.wav (any audio file you wish to transcribe)
  4. Then, run the following Python cell:
import os
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download
from tokenizers import Tokenizer
import soundfile as sf
import librosa
import logging
import matplotlib.pyplot as plt
from IPython.display import display, Audio
import torch

def _create_4d_causal_attention_mask(input_shape, past_sequence_length, dtype=np.float32):
    batch_size, sequence_length = input_shape
    total_sequence_length = past_sequence_length + sequence_length

    mask = np.tril(np.ones((total_sequence_length, total_sequence_length), dtype=np.bool_))
    mask = mask[past_sequence_length:, :]

    causal_mask = np.zeros((batch_size, 1, sequence_length, total_sequence_length), dtype=dtype)
    causal_mask[:, :, :, :] = np.where(
        mask[None, None, :, :], 0.0, np.finfo(dtype).min
    )
    return causal_mask

# --- Inference configuration -------------------------------------------------
repo_id = "onnx-community/Voxtral-Mini-3B-2507-ONNX"
audio_file_path = "audio.wav"
# Decoder produced by the export + quantization scripts; it also returns
# per-layer attention weights, which the alignment step below needs.
custom_decoder_path = "decoder_model_attentive_q4_weight_only_inc.onnx"
max_generation_tokens = 999
eos_token_id = 2  # greedy decoding stops when this id is predicted

print(f"Downloading base model files from {repo_id}...")
# Fetch only the q4 audio encoder, q4 token-embedding model, q4 merged decoder
# and the tokenizer. The merged decoder is downloaded but is not loaded below;
# the custom attention-exposing decoder is used instead.
local_dir = snapshot_download(
    repo_id=repo_id,
    repo_type="model",
    allow_patterns=["onnx/audio_encoder_q4.*", "onnx/embed_tokens_q4.*", "onnx/decoder_model_merged_q4.*", "tokenizer.json"],
)
onnx_dir = os.path.join(local_dir, "onnx")
tok = Tokenizer.from_file(os.path.join(local_dir, "tokenizer.json"))
# Hard-coded special-token ids used to assemble the prompt (presumably BOS,
# instruction start, begin-audio, audio placeholder, instruction end).
# NOTE(review): confirm against tokenizer.json if the tokenizer changes.
bos_id, inst_id, baud_id, aud_id, einst_id = 1, 3, 25, 24, 4
ae_path = os.path.join(onnx_dir, "audio_encoder_q4.onnx")
embed_path = os.path.join(onnx_dir, "embed_tokens_q4.onnx")
if not os.path.exists(custom_decoder_path):
    raise FileNotFoundError(f"Custom ONNX decoder not found at '{custom_decoder_path}'.")
# All three sessions run on CPU with default session options.
sess_opts = ort.SessionOptions()
session_providers = ["CPUExecutionProvider"]
ae_sess = ort.InferenceSession(ae_path, sess_options=sess_opts, providers=session_providers)
embed_sess = ort.InferenceSession(embed_path, sess_options=sess_opts, providers=session_providers)
dec_sess = ort.InferenceSession(custom_decoder_path, sess_options=sess_opts, providers=session_providers)
# The decoder has one "past_key_values.{i}.key" input per layer, so counting
# inputs whose name ends in ".key" yields the layer count.
num_decoder_layers = sum(1 for i in dec_sess.get_inputs() if i.name.endswith(".key"))
print(f"Detected {num_decoder_layers} decoder layers.")

def extract_mel_features_for_chunk(audio_chunk, sampling_rate=16000, n_fft=400, hop_length=160, n_mels=128, target_length=3000):
    """Compute normalized log-mel features for one audio chunk.

    The chunk is padded or truncated to exactly 30 seconds and turned into a
    mel spectrogram. The spectrogram is then log10-compressed (floored at
    1e-10) and clamped to at most 8 below its maximum. Finally it is rescaled
    with ``(x + 4) / 4`` and cut to at most ``target_length`` frames.

    Returns:
        float32 array with ``n_mels`` rows.
    """
    padded = librosa.util.fix_length(audio_chunk, size=30 * sampling_rate)
    power_mel = librosa.feature.melspectrogram(
        y=padded,
        sr=sampling_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    log_mel = np.log10(np.maximum(power_mel, 1e-10))
    dynamic_floor = log_mel.max() - 8.0
    log_mel = np.maximum(log_mel, dynamic_floor)
    normalized = (log_mel + 4.0) / 4.0
    return normalized[:, :target_length].astype(np.float32)

def process_long_audio(audio_path, session, sampling_rate=16000):
    """Encode an audio file of any length with the ONNX audio encoder.

    The file is down-mixed to mono, resampled to ``sampling_rate`` and split
    into consecutive 30-second chunks. Each chunk becomes log-mel features
    and is run through ``session``. The per-chunk encoder outputs are
    concatenated along axis 0.

    Args:
        audio_path: Path of the audio file to read.
        session: ONNX Runtime session of the audio encoder.
        sampling_rate: Target sampling rate in Hz.

    Returns:
        Concatenated encoder outputs of all chunks.

    Raises:
        ValueError: If the file contains no samples.
    """
    y, sr = sf.read(audio_path)
    if y.ndim > 1:
        y = y.mean(axis=1)
    if sr != sampling_rate:
        y = librosa.resample(y, orig_sr=sr, target_sr=sampling_rate)

    chunk_samples = 30 * sampling_rate
    num_chunks = int(np.ceil(len(y) / chunk_samples))
    if num_chunks == 0:
        raise ValueError(f"Audio file '{audio_path}' contains no samples.")

    # The encoder's input name does not change between chunks; look it up once.
    input_name = session.get_inputs()[0].name
    all_audio_embeds = []
    print(f"Processing in {num_chunks} chunk(s)...")
    for i in range(num_chunks):
        chunk = y[i * chunk_samples:(i + 1) * chunk_samples]
        mel_features = extract_mel_features_for_chunk(chunk, sampling_rate)
        all_audio_embeds.append(session.run(None, {input_name: mel_features[None, :]})[0])
    return np.concatenate(all_audio_embeds, axis=0)

# Encode the audio and build the prompt embeddings: bos, inst and baud token
# ids, then one aud_id placeholder per encoder frame, then the tokens of
# "Transcribe.", then einst_id.
if not os.path.exists(audio_file_path):
    raise FileNotFoundError(f"Audio file '{audio_file_path}' not found.")

audio_embeds_raw = process_long_audio(audio_file_path, ae_sess)
batch_size = 1
audio_output_frames = audio_embeds_raw.shape[0] // batch_size
audio_embeds = audio_embeds_raw.reshape(batch_size, audio_output_frames, -1)

text_instruction_ids = tok.encode("Transcribe.", add_special_tokens=False).ids
prompt_tokens = [bos_id, inst_id, baud_id, *([aud_id] * audio_output_frames), *text_instruction_ids, einst_id]
initial_sequence_length = len(prompt_tokens)

# Audio placeholders start right after the three leading special tokens;
# their embeddings are overwritten with the encoder output.
audio_start = 3
prompt_ids = np.array([prompt_tokens], dtype=np.int64)
inputs_embeds = embed_sess.run(None, {"input_ids": prompt_ids})[0]
inputs_embeds[0, audio_start:audio_start + audio_output_frames, :] = audio_embeds[0]
inputs_embeds = inputs_embeds.astype(np.float32)

# Greedy autoregressive decoding with the custom ONNX decoder and its KV cache.
# The empty first-step cache uses the head count / head size declared on the
# decoder's "past_key_values.0.key" input when those axes are static, falling
# back to 8 heads x 128 dims otherwise.
_dec_key_shape = next(
    (list(inp.shape) for inp in dec_sess.get_inputs() if inp.name == "past_key_values.0.key"),
    [],
)
if len(_dec_key_shape) != 4:
    _dec_key_shape = [None] * 4
dec_kv_heads = _dec_key_shape[1] if isinstance(_dec_key_shape[1], int) else 8
dec_head_dim = _dec_key_shape[3] if isinstance(_dec_key_shape[3], int) else 128

generated_ids = []
past_key_values = None
# Number of positions currently stored in the KV cache.
current_past_len = 0
for i in range(max_generation_tokens):
    dec_inputs = {}
    if i == 0:
        # First step: the whole prompt, with an empty cache.
        dec_inputs["inputs_embeds"] = inputs_embeds
    else:
        # Later steps: only the most recently generated token.
        last_token_id = np.array([[generated_ids[-1]]], dtype=np.int64)
        dec_inputs["inputs_embeds"] = embed_sess.run(None, {"input_ids": last_token_id})[0].astype(np.float32)
    step_len = dec_inputs["inputs_embeds"].shape[1]
    # Mask covers cached and new positions: (batch, 1, step_len, past + step_len).
    dec_inputs["attention_mask"] = _create_4d_causal_attention_mask((batch_size, step_len), current_past_len)

    for layer in range(num_decoder_layers):
        if past_key_values is not None:
            past_key = past_key_values[layer * 2]
            past_value = past_key_values[layer * 2 + 1]
        else:
            past_key = past_value = np.zeros((batch_size, dec_kv_heads, 0, dec_head_dim), dtype=np.float32)
        dec_inputs[f"past_key_values.{layer}.key"] = past_key.astype(np.float32, copy=False)
        dec_inputs[f"past_key_values.{layer}.value"] = past_value.astype(np.float32, copy=False)

    outputs = dec_sess.run(None, dec_inputs)
    # Output layout: [logits, present.0.key, present.0.value, ..., attention.0, ...].
    logits, past_key_values = outputs[0], outputs[1:1 + num_decoder_layers * 2]
    # The returned cache now also contains the positions just fed in.
    current_past_len += step_len

    next_token_id = int(np.argmax(logits[0, -1, :]))
    if next_token_id == eos_token_id:
        break
    generated_ids.append(next_token_id)
    print(tok.decode([next_token_id]), end="", flush=True)

# Teacher-forced pass over prompt + generated tokens with an empty cache, to
# collect every layer's attention weights for the complete sequence.
full_sequence_ids = np.array([prompt_tokens + generated_ids], dtype=np.int64)
full_embeds = embed_sess.run(None, {"input_ids": full_sequence_ids})[0]
# Replace the audio placeholder embeddings (positions 3 .. 3 + frames) with the encoder output.
full_embeds[0, 3:3 + audio_output_frames, :] = audio_embeds[0]
full_embeds = full_embeds.astype(np.float32)

# Empty-cache shape: head count / head size come from the decoder's
# "past_key_values.0.key" input when static, else fall back to 8 x 128.
_align_key_shape = next(
    (list(inp.shape) for inp in dec_sess.get_inputs() if inp.name == "past_key_values.0.key"),
    [],
)
if len(_align_key_shape) != 4:
    _align_key_shape = [None] * 4
align_kv_heads = _align_key_shape[1] if isinstance(_align_key_shape[1], int) else 8
align_head_dim = _align_key_shape[3] if isinstance(_align_key_shape[3], int) else 128
empty_past = np.zeros((batch_size, align_kv_heads, 0, align_head_dim), dtype=np.float32)

alignment_inputs = {
    "inputs_embeds": full_embeds,
    "attention_mask": _create_4d_causal_attention_mask(full_embeds.shape[:2], 0),
}
for layer in range(num_decoder_layers):
    alignment_inputs[f"past_key_values.{layer}.key"] = empty_past
    alignment_inputs[f"past_key_values.{layer}.value"] = empty_past

alignment_outputs = dec_sess.run(None, alignment_inputs)
# Everything after logits and the 2 * num_layers present tensors is the
# per-layer attention output.
attentions = [torch.from_numpy(attn) for attn in alignment_outputs[1 + num_decoder_layers * 2:]]

# Rows of interest: positions of the generated tokens. Columns of interest:
# positions of the audio placeholder tokens (3 .. 3 + frames).
text_start_idx = len(prompt_tokens)
audio_end_idx = 3 + audio_output_frames
start_layer, end_layer = 10, 20
# Slicing (instead of indexing) avoids IndexError on decoders with fewer
# than end_layer layers; an empty selection is reported explicitly.
selected_attentions = attentions[start_layer:end_layer]
if not selected_attentions:
    raise ValueError(
        f"Layer range [{start_layer}, {end_layer}) selects none of the "
        f"{len(attentions)} attention tensors returned by the decoder."
    )

layer_attentions = []
for layer_attn_batch in selected_attentions:
    # assumes (batch, heads, query, key): take batch item 0 and average heads.
    layer_attn_avg_heads = layer_attn_batch[0].mean(dim=0)
    relevant_attns = layer_attn_avg_heads[text_start_idx:, 3:audio_end_idx]
    if relevant_attns.numel() > 0:
        layer_attentions.append(relevant_attns)

if not layer_attentions:
    raise ValueError("Could not extract any valid attention weights. The generated text might be empty.")

# Average the selected layers, then sharpen each token's distribution over
# audio frames with a low-temperature softmax.
avg_attentions = torch.stack(layer_attentions).mean(dim=0)
temperature = 0.1
weights = torch.nn.functional.softmax(avg_attentions / temperature, dim=1).cpu().numpy()

# Save a heat map of the sharpened token-to-frame alignment weights.
fig, ax = plt.subplots(figsize=(10, 10))
image = ax.imshow(weights, aspect="auto", origin="lower", cmap="viridis")
ax.set_xlabel("Audio Frames")
ax.set_ylabel("Generated Text Tokens")
ax.set_title("Audio-to-Text Alignment Matrix (Sharpened)")
fig.colorbar(image)
fig.savefig("alignment_matrix.png")

# DTW over a (frame x token) cost equal to the negated alignment weights, so
# the cheapest monotonic path follows the strongest attention.
cost_matrix = -weights.T
D, wp = librosa.sequence.dtw(C=cost_matrix.astype(np.float32), backtrack=True)
# Reverse the warping path so it is traversed from start to end.
wp = wp[::-1]

token_to_frame_map = {}
for frame_idx, token_idx in wp:
    # Keep the first frame each token is matched to along the forward path.
    token_to_frame_map.setdefault(token_idx, frame_idx)

# Group generated token ids into words: a token whose decoded text starts with
# a space opens a new word (unless no word has been started yet).
word_groups = []
current_word_tokens = []
for token_id in generated_ids:
    opens_new_word = tok.decode([token_id]).startswith(" ")
    if opens_new_word and current_word_tokens:
        word_groups.append(current_word_tokens)
        current_word_tokens = []
    current_word_tokens.append(token_id)
if current_word_tokens:
    word_groups.append(current_word_tokens)

# Convert DTW frame indices to seconds. The audio was encoded in 30-second
# chunks, each padded to full length, so the encoder frames span 30 s per
# chunk rather than 30 s in total.
CHUNK_DURATION = 30.0
_audio_info = sf.info(audio_file_path)
_num_audio_chunks = max(1, int(np.ceil(_audio_info.frames / _audio_info.samplerate / CHUNK_DURATION)))
EFFECTIVE_AUDIO_DURATION = CHUNK_DURATION * _num_audio_chunks
AUDIO_TIME_PER_FRAME = EFFECTIVE_AUDIO_DURATION / audio_output_frames

results = []
# Index (into generated_ids) of the first token of the current word group.
token_idx_counter = 0
previous_word_end_time = token_to_frame_map.get(0, 0) * AUDIO_TIME_PER_FRAME

for word_group in word_groups:
    first_token_idx = token_idx_counter
    # Advance before any skip so later words keep their correct token indices.
    token_idx_counter += len(word_group)

    word_text = tok.decode(word_group).strip()
    if not word_text:
        continue

    start_time = previous_word_end_time
    last_token_in_word_idx = first_token_idx + len(word_group) - 1
    end_frame = token_to_frame_map.get(last_token_in_word_idx, 0)
    # Clamp so a word never ends before it starts.
    end_time = max(start_time, end_frame * AUDIO_TIME_PER_FRAME)
    results.append({"word": word_text, "start": start_time, "end": end_time})
    previous_word_end_time = end_time

for res in results:
    print(f"[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")

# Load the audio at 16 kHz and play back the snippet for each timestamped word.
SAMPLING_RATE = 16000
y, sr = librosa.load(audio_file_path, sr=SAMPLING_RATE)
if results:
    for res in results:
        start_sample, end_sample = (int(res[key] * SAMPLING_RATE) for key in ("start", "end"))
        snippet = y[start_sample:end_sample]
        print(f"\n[{res['start']: >6.2f}s -> {res['end']: >6.2f}s] {res['word']}")
        if len(snippet) == 0:
            print("   (No audio for this segment)")
        else:
            display(Audio(snippet, rate=SAMPLING_RATE))
else:
    print("No words were transcribed to verify.")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for urroxyz/Voxtral-Mini-3B-2507_timestamped

Quantized
(6)
this model

Collection including urroxyz/Voxtral-Mini-3B-2507_timestamped