SoybeanMilk/Breeze-ASR-26-quantized.w4a16

This model is a 4-bit quantized version of MediaTek-Research/Breeze-ASR-26, optimized for high-throughput inference using vLLM and Marlin acceleration kernels.

Model Details

  • Base Model: MediaTek-Research/Breeze-ASR-26
  • Quantization Method: GPTQ (W4A16)
  • Quantization Tool: llm-compressor
  • Calibration Dataset: MLCommons/peoples_speech (clean subset, 256 samples)
  • Configuration:
    • Weights: 4-bit
    • Activations: 16-bit (FP16/BF16)
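    • Group Size: 128 (one scale per 128 weights, per the W4A16 preset)
    • GPTQ Block Size: 128 (columns processed per GPTQ update block)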
    • Dampening Fraction: 0.01
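To make the W4A16 configuration concrete, here is a minimal NumPy sketch of symmetric 4-bit group-wise quantization with a group size of 128. It assumes the symmetric INT4 scheme with one scale per group of input weights, which is what the W4A16 preset describes. It uses plain round-to-nearest. GPTQ uses the same kind of grid but decides the rounding with error compensation. `quantize_groups` and `dequantize_groups` are hypothetical helper names, not llm-compressor APIs.

```python
import numpy as np

GROUP_SIZE = 128    # matches the configured group size
QMIN, QMAX = -8, 7  # signed 4-bit integer range


def quantize_groups(w: np.ndarray, group_size: int = GROUP_SIZE):
    """Symmetric round-to-nearest INT4 quantization with one scale per group.

    w has shape (out_features, in_features); groups run along in_features.
    Returns int8 codes in [-8, 7] and scales of shape (out_features, n_groups).
    """
    out_f, in_f = w.shape
    assert in_f % group_size == 0, "in_features must be a multiple of group_size"
    groups = w.reshape(out_f, in_f // group_size, group_size)
    # Map the largest magnitude in each group to the edge of the grid.
    scales = np.abs(groups).max(axis=-1) / QMAX
    scales = np.where(scales == 0, 1.0, scales)  # guard all-zero groups
    q = np.clip(np.round(groups / scales[..., None]), QMIN, QMAX).astype(np.int8)
    return q.reshape(out_f, in_f), scales


def dequantize_groups(q: np.ndarray, scales: np.ndarray, group_size: int = GROUP_SIZE):
    """Map INT4 codes back to floats, as a W4A16 kernel does on the fly before the matmul."""
    out_f, in_f = q.shape
    groups = q.reshape(out_f, in_f // group_size, group_size).astype(np.float32)
    return (groups * scales[..., None]).reshape(out_f, in_f)
```

With 4-bit codes and one 16-bit scale per 128 weights, storage is about 4 + 16/128 = 4.125 bits per quantized weight. No zero points are needed because the scheme is symmetric.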

Quantization Workflow

Quantizing this model required working around several limitations in how the llm-compressor framework handles multimodal encoder-decoder models such as Whisper.

1. Environment Setup

A clean Conda environment was established with Python 3.11. Specific version locking was required to avoid library conflicts:

  • vllm >= 0.7.3: For inference and Marlin kernel support.
  • llmcompressor >= 0.4.0: For the core GPTQ quantization logic.
  • datasets < 3.0.0: To avoid mandatory torchcodec dependencies which conflict with CUDA-enabled PyTorch symbols in certain Linux/WSL environments.
  • numpy < 2.3: To maintain compatibility with librosa and numba.
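The pins above can be checked before a long quantization run. The following is a minimal sketch using only the standard library (`importlib.metadata`). The version comparison is a simplified stand-in for full PEP 440 handling: pre-release, post-release and local tags are ignored. The helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Version constraints from the list above: package -> (operator, bound)
PINS = {
    "vllm": (">=", "0.7.3"),
    "llmcompressor": (">=", "0.4.0"),
    "datasets": ("<", "3.0.0"),
    "numpy": ("<", "2.3"),
}


def parse(v: str) -> tuple:
    """Keep the leading numeric release segment: '0.7.3.dev1+cu121' -> (0, 7, 3)."""
    nums = []
    for part in v.split("."):
        m = re.match(r"\d+", part)
        if not m:
            break
        nums.append(int(m.group()))
        if m.end() != len(part):  # e.g. '4+cu121' or '0rc1': stop after the number
            break
    return tuple(nums)


def compare(a: str, b: str) -> int:
    """Return -1, 0 or 1, padding the shorter release tuple with zeros ('2.3' == '2.3.0')."""
    x, y = parse(a), parse(b)
    n = max(len(x), len(y))
    x += (0,) * (n - len(x))
    y += (0,) * (n - len(y))
    return (x > y) - (x < y)


def satisfies(installed: str, op: str, bound: str) -> bool:
    c = compare(installed, bound)
    if op == ">=":
        return c >= 0
    if op == "<":
        return c < 0
    raise ValueError(f"unsupported operator {op!r}")  # only the two operators used above


def check_environment(pins=PINS) -> dict:
    """Map each package to (installed version or None, constraint satisfied?)."""
    report = {}
    for pkg, (op, bound) in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            report[pkg] = (None, False)
            continue
        report[pkg] = (installed, satisfies(installed, op, bound))
    return report


if __name__ == "__main__":
    for pkg, (installed, ok) in check_environment().items():
        print(f"{pkg:15} {installed or 'missing':12} {'OK' if ok else 'CHECK'}")
```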

2. Overcoming Engineering Challenges

While quantizing this multimodal Whisper model, several issues had to be worked around with monkey patching:

  • Multimodal Auto-loading Bug: The oneshot API in llm-compressor can misidentify Whisper as a CausalLM, which leaves the model incompletely loaded. The fix was to instantiate WhisperForConditionalGeneration manually and pass the model object to the compressor.
  • Tracer Missing Decoder Inputs: The underlying torch.fx tracer often strips decoder_input_ids during subgraph analysis. The model's forward method was monkey-patched to re-insert the standard Whisper decoder prompt for Chinese transcription ([<|startoftranscript|>, <|zh|>, <|transcribe|>, <|notimestamps|>]) whenever those inputs are lost during tracing.
  • Data Collator Filtering: The framework's default collators drop non-text features such as input_features. A custom data collator and a patched get_calibration_dataloader keep both the audio features and the decoder token IDs during calibration.
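The reproduction script below does not include the forward patch itself. The following is a minimal sketch of what such a patch could look like. It assumes the model is called with keyword arguments `input_features` and `decoder_input_ids`. `patch_forward_with_default_prompt` and its `to_tensor` hook are hypothetical names.

```python
import functools

# The decoder prompt tokens named in the text above (Chinese, transcribe, no timestamps).
PROMPT_TOKENS = ["<|startoftranscript|>", "<|zh|>", "<|transcribe|>", "<|notimestamps|>"]


def patch_forward_with_default_prompt(model, prompt_ids, to_tensor):
    """Hypothetical helper: make model.forward fill in decoder_input_ids when they are missing.

    prompt_ids: token ids for PROMPT_TOKENS.
    to_tensor:  callable turning a list of id lists into a tensor (device/dtype chosen by the caller).
    """
    original_forward = model.forward  # bound method of the unpatched model

    @functools.wraps(original_forward)
    def forward(input_features=None, decoder_input_ids=None, **kwargs):
        if decoder_input_ids is None and input_features is not None:
            batch_size = len(input_features)  # first dimension of the feature batch
            rows = [list(prompt_ids) for _ in range(batch_size)]
            decoder_input_ids = to_tensor(rows)
        return original_forward(
            input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs
        )

    # Assigning on the instance shadows the class method for this object only.
    model.forward = forward
    return model
```

With a real model, the ids could come from `processor.tokenizer.convert_tokens_to_ids(PROMPT_TOKENS)`, and `to_tensor` could be `lambda rows: torch.tensor(rows, device=model.device)`. Whether a given tracer picks up an instance-level override depends on how it invokes the model, so treat this as illustrative only.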

3. Execution & Technical Details

The quantization was performed using the oneshot API with the following technical specifics:

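  • Algorithm: GPTQ, a scalable approximation of Optimal Brain Quantization (OBQ), quantizes each layer's weights column by column. It minimizes the squared error of the layer's output on the calibration activations, ||WX − ŴX||², rather than the error between the FP16 and 4-bit weights themselves. It uses the inverse of the Hessian H = 2XXᵀ, dampened by adding dampening_frac times its mean diagonal, to redistribute each column's rounding error onto the columns not yet quantized.
  • Layer Targeting: All Linear layers in both the WhisperEncoder and WhisperDecoder were targeted. The final output projection was meant to be excluded with ignore=["lm_head"] to avoid accuracy loss in the output logits used for token selection. In the Hugging Face Whisper implementation this layer is named proj_out, so it is worth confirming that the ignore pattern actually matches it.
  • Weight Packing: The weights are stored in the compressed-tensors format, which vLLM recognizes and serves with Marlin kernels. Marlin is a mixed-precision matrix-multiplication kernel (16-bit activations × INT4 weights) for NVIDIA GPUs with compute capability 8.0+ (sm_80, Ampere and newer). It is designed to run close to the GPU's memory-bandwidth limit at small batch sizes.
  • Audio Preprocessing: Calibration samples were resampled to 16,000 Hz. Features were extracted with the WhisperProcessor, producing an input_features tensor of shape [80, 3000]: 80 mel bins × 3000 frames, where a 30-second window at a 160-sample hop gives 30 × 16000 / 160 = 3000 frames.
  • Calibration Set Size: 256 clean audio samples were used. Using more samples gave diminishing returns in perplexity while significantly increasing quantization time.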
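The GPTQ update described above can be sketched in a few lines of NumPy. This is a simplified illustration, not llm-compressor's implementation. It uses one scale per output row instead of groups of 128, and it processes columns one at a time without the lazy block updates that `block_size` controls. `gptq_quantize` and `rtn_quantize` are hypothetical names.

```python
import numpy as np


def rtn_quantize(col: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Round values onto a symmetric INT4 grid and return the dequantized result."""
    return np.clip(np.round(col / scale), -8, 7) * scale


def gptq_quantize(w: np.ndarray, x: np.ndarray, damp_frac: float = 0.01) -> np.ndarray:
    """Simplified GPTQ for one linear layer y = W x.

    w: (out_features, in_features) weights; x: (in_features, n_samples) calibration inputs.
    Returns dequantized weights that lie on the INT4 grid.
    """
    w = np.array(w, dtype=np.float64)  # working copy; columns get error-compensated in place
    n = w.shape[1]
    h = 2.0 * x @ x.T  # Hessian of ||W X - W_hat X||^2 for each row of W
    h += damp_frac * np.mean(np.diag(h)) * np.eye(n)  # dampening, as with dampening_frac
    # Upper Cholesky factor of H^-1; row i holds the error-propagation coefficients for column i.
    hinv = np.linalg.cholesky(np.linalg.inv(h)).T
    scale = np.abs(w).max(axis=1) / 7  # one scale per output row (no groups in this sketch)
    scale = np.where(scale == 0, 1.0, scale)
    q = np.zeros_like(w)
    for i in range(n):
        q[:, i] = rtn_quantize(w[:, i], scale)
        err = (w[:, i] - q[:, i]) / hinv[i, i]
        # Push this column's rounding error onto the columns not yet quantized.
        w[:, i + 1:] -= np.outer(err, hinv[i, i + 1:])
    return q
```

When the calibration inputs are uncorrelated (diagonal H), the off-diagonal coefficients are zero, so no error is propagated and GPTQ reduces to plain round-to-nearest. The compensation only pays off when input features are correlated, as they are in real activations.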

Evaluation and Validation

The model was validated using the test_mayday.mp3 sample.
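One way to make a single-sample check more quantitative is to compare the W4A16 transcript against the FP16 model's transcript using character error rate (CER), a common metric for Chinese ASR. The following is a minimal sketch based on Levenshtein distance over characters, with whitespace removed. The two transcripts are assumed inputs, and no results are implied. Scoring an empty reference against a non-empty hypothesis as 1.0 is a chosen convention.

```python
def levenshtein(a: str, b: str) -> int:
    """Edit distance (insertions, deletions, substitutions) between two strings."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        cur = [i]
        for j, cb in enumerate(b, start=1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (ca != cb)))  # substitution (free if equal)
        prev = cur
    return prev[-1]


def cer(reference: str, hypothesis: str) -> float:
    """Character error rate of `hypothesis` against `reference`, ignoring whitespace."""
    ref = "".join(reference.split())
    hyp = "".join(hypothesis.split())
    if not ref:
        return 0.0 if not hyp else 1.0
    return levenshtein(ref, hyp) / len(ref)
```

For example, `cer(fp16_text, w4a16_text)` treats the FP16 output as the reference. The absolute error of both models requires a human-verified transcript instead.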

  • Inference Latency: On RTX 40-series hardware, the W4A16 version shows significantly lower Time-To-First-Token (TTFT) and lower Inter-Token Latency (ITL), and therefore higher decoding throughput, than the original FP16 model.
  • VRAM Savings: The weight footprint drops from ~3.1 GB to ~0.9 GB. This allows deployment on 8 GB consumer GPUs with ample room left for the KV cache.
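The footprint figures can be sanity-checked with simple arithmetic. The following is a minimal sketch that counts weights only, with no KV cache or activations. It assumes 4-bit packed weights plus one 16-bit scale per group of 128 and no zero points (symmetric scheme). Everything not quantized stays in 16-bit. The parameter split in the example is a hypothetical round number, not a measured count for this model.

```python
def estimate_weight_bytes(n_quantized: float, n_unquantized: float,
                          bits: int = 4, group_size: int = 128,
                          scale_bytes: int = 2, dense_bytes: int = 2) -> float:
    """Approximate size of the weights alone (no KV cache, no activations).

    n_quantized:   parameters in quantized Linear layers (packed to `bits` each,
                   plus one `scale_bytes` scale per `group_size` weights).
    n_unquantized: parameters kept in 16-bit (embeddings, convolutions, norms, ignored layers).
    """
    packed = n_quantized * bits / 8
    scales = n_quantized / group_size * scale_bytes
    dense = n_unquantized * dense_bytes
    return packed + scales + dense


if __name__ == "__main__":
    # Hypothetical split: ~1.55e9 total parameters, ~1.45e9 of them in quantized Linear layers.
    print(estimate_weight_bytes(0, 1.55e9) / 1e9)       # all FP16: 3.1 (GB)
    print(estimate_weight_bytes(1.45e9, 0.1e9) / 1e9)  # W4A16: ~0.95 (GB)
```

Under these assumed counts, the result lands in the same range as the ~3.1 GB and ~0.9 GB figures above. The unquantized embeddings and other 16-bit tensors account for a noticeable share of the smaller number.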

Usage

This model is intended for vLLM. On supported GPUs, vLLM automatically serves this checkpoint with the Marlin kernel, which gives significant speedups over FP16.

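```python
from vllm import LLM, SamplingParams
import librosa

# Load audio as mono, resampled to 16 kHz (Whisper's expected sample rate)
audio, sr = librosa.load("your_audio.mp3", sr=16000)

# Initialize engine (Hugging Face repo id, or the local output directory of the quantization run)
llm = LLM(
    model="SoybeanMilk/Breeze-ASR-26-quantized.w4a16",
    max_model_len=448,
    limit_mm_per_prompt={"audio": 1},
)

# Explicit encoder/decoder prompt: the audio goes to the encoder,
# the Whisper task tokens (Chinese, transcribe, no timestamps) go to the decoder
prompts = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    "decoder_prompt": "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>",
}

# Greedy decoding (temperature=0)
outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=256))
print(outputs[0].outputs[0].text)
```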
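Whisper's encoder consumes fixed 30-second windows, so longer recordings are typically truncated to the first window. The following is a minimal sketch that splits a 16 kHz waveform into consecutive windows of at most 30 seconds and builds one prompt per window, in the same shape as the example above. Naive fixed-length cuts can split a word at a boundary. The helper names are hypothetical.

```python
import numpy as np

SAMPLE_RATE = 16_000
CHUNK_SECONDS = 30  # Whisper's fixed encoder window
DECODER_PROMPT = "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>"


def chunk_audio(audio: np.ndarray, sr: int = SAMPLE_RATE, seconds: int = CHUNK_SECONDS):
    """Split a mono waveform into consecutive, non-overlapping windows of at most `seconds`."""
    step = sr * seconds
    return [audio[start:start + step] for start in range(0, len(audio), step)]


def build_prompts(audio: np.ndarray, sr: int = SAMPLE_RATE):
    """One encoder/decoder prompt dict per chunk."""
    return [
        {
            "encoder_prompt": {"prompt": "", "multi_modal_data": {"audio": (chunk, sr)}},
            "decoder_prompt": DECODER_PROMPT,
        }
        for chunk in chunk_audio(audio, sr)
    ]
```

`llm.generate` accepts a list of prompts and returns outputs in the same order. A full transcript can therefore be assembled with `"".join(o.outputs[0].text for o in llm.generate(build_prompts(audio, sr), params))`.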

Reproduction Script

The script below reproduces the core quantization logic. It includes the model-loading workaround and the calibration-dataloader patch for the llm-compressor Whisper issues described above:

```python
import importlib

import torch
from datasets import Audio, Dataset, load_dataset
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import GPTQModifier
from torch.utils.data import DataLoader
from transformers import AutoProcessor, WhisperForConditionalGeneration

MODEL_ID = "Breeze-ASR-26"  # Local directory containing the base model
OUTPUT_DIR = "Breeze-ASR-26-quantized.w4a16"

# Instantiate the processor and the full Whisper encoder-decoder model
processor = AutoProcessor.from_pretrained(MODEL_ID)
print("Loading full Whisper model to avoid oneshot CausalLM auto-loading bug...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto", low_cpu_mem_usage=True)

# 1. Prepare calibration dataset
def preprocess_fn(example):
    audio_data = example["audio"]["array"]
    text = example["text"]

    # Log-mel features for the encoder (the batch dimension of 1 is squeezed below)
    input_features = processor(
        audio=audio_data,
        sampling_rate=16000,
        return_tensors="pt",
    ).input_features

    # Tokenized transcript used as decoder input during calibration
    decoder_input_ids = processor.tokenizer(
        text,
        return_tensors="pt",
    ).input_ids

    return {
        "input_features": input_features.squeeze(0).tolist(),
        "decoder_input_ids": decoder_input_ids.squeeze(0).tolist(),
    }

print("Loading dataset...")
dataset_stream = load_dataset("MLCommons/peoples_speech", "clean", split="train", streaming=True)
# Decode audio at 16 kHz before feature extraction
dataset_stream = dataset_stream.cast_column("audio", Audio(sampling_rate=16000))
dataset_stream = dataset_stream.take(256)
dataset_stream = dataset_stream.map(preprocess_fn)

print("Converting streaming dataset to local dataset...")
data_list = [
    {"input_features": item["input_features"], "decoder_input_ids": item["decoder_input_ids"]}
    for item in dataset_stream
]
my_dataset = Dataset.from_list(data_list)

def custom_data_collator(features):
    # Keep the audio features and decoder ids that the default collators would drop;
    # cast features to the model dtype so FP16/BF16 weights do not meet FP32 inputs
    return {
        "input_features": torch.tensor([f["input_features"] for f in features], dtype=model.dtype),
        "decoder_input_ids": torch.tensor([f["decoder_input_ids"] for f in features]),
    }

# Monkey patch: bypass llm-compressor's text-only filtering of calibration data.
# importlib returns the module object itself, whereas attribute access such as
# llmcompressor.entrypoints.oneshot may resolve to the oneshot function instead.
oneshot_module = importlib.import_module("llmcompressor.entrypoints.oneshot")

def mock_get_calibration_dataloader(*args, **kwargs):
    return DataLoader(my_dataset, batch_size=1, collate_fn=custom_data_collator)

oneshot_module.get_calibration_dataloader = mock_get_calibration_dataloader

# 2. Configure quantization recipe
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    block_size=128,
    dampening_frac=0.01,
    sequential_targets=["WhisperEncoderLayer", "WhisperDecoderLayer"],
    # Hugging Face Whisper names its output projection "proj_out";
    # check that this pattern skips the intended layer.
    ignore=["lm_head"],
)

# 3. Run quantization
print("Starting quantization...")
oneshot(
    model=model,
    dataset=my_dataset,
    recipe=recipe,
    output_dir=OUTPUT_DIR,
    max_seq_length=448,
)

# Save tokenizer and feature-extractor files next to the weights so vLLM can load the directory
processor.save_pretrained(OUTPUT_DIR)
print("Quantization completed successfully!")
```
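After the run, it is worth confirming what was written to the output directory. The following is a minimal sketch that reads `config.json` and extracts the quantization fields. It assumes the checkpoint carries a compressed-tensors style `quantization_config` block with `quant_method`, `format` and `ignore` keys; missing keys come back empty. The helper names are hypothetical.

```python
import json
from pathlib import Path


def load_config(model_dir: str) -> dict:
    """Read config.json from a saved checkpoint directory."""
    return json.loads((Path(model_dir) / "config.json").read_text(encoding="utf-8"))


def summarize_quant_config(config: dict) -> dict:
    """Pull out the quantization fields worth checking; missing keys come back as None / []."""
    qcfg = config.get("quantization_config") or {}
    return {
        "quant_method": qcfg.get("quant_method"),
        "format": qcfg.get("format"),
        "ignore": qcfg.get("ignore", []),
    }
```

For example, `summarize_quant_config(load_config("Breeze-ASR-26-quantized.w4a16"))` should report compressed-tensors as the method. Its `ignore` list shows whether the output projection (`lm_head` or `proj_out`) was actually left unquantized.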

Quantized and optimized by SoybeanMilk.
