SoybeanMilk/Breeze-ASR-26-quantized.w4a16
This model is a 4-bit quantized version of MediaTek-Research/Breeze-ASR-26, optimized for high-throughput inference using vLLM and Marlin acceleration kernels.
Model Details
- Base Model: MediaTek-Research/Breeze-ASR-26
- Quantization Method: GPTQ (W4A16)
- Quantization Tool: llm-compressor
- Calibration Dataset: MLCommons/peoples_speech (clean subset, 256 samples)
- Configuration:
  - Weights: 4-bit
  - Activations: 16-bit (FP16/BF16)
  - Block Size (Group Size): 128
  - Dampening Fraction: 0.01
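If you want to confirm these settings against the published checkpoint, the quantization metadata written by llm-compressor can be read back from the repository's config.json. The snippet below is a minimal sketch, assuming the checkpoint carries a compressed-tensors quantization_config entry in its config.

# Sketch: print the quantization_config stored in the published config.json
# (assumes llm-compressor wrote this key into the checkpoint).
import json
from huggingface_hub import hf_hub_download

config_path = hf_hub_download("SoybeanMilk/Breeze-ASR-26-quantized.w4a16", "config.json")
with open(config_path) as f:
    config = json.load(f)
print(json.dumps(config.get("quantization_config", {}), indent=2))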
Quantization Workflow
The quantization of this model involved a rigorous engineering process to overcome several architectural hurdles in the llm-compressor framework regarding multi-modal Encoder-Decoder models (Whisper).
1. Environment Setup
A clean Conda environment was established with Python 3.11. Specific version locking was required to avoid library conflicts:
- vllm >= 0.7.3: for inference and Marlin kernel support.
- llmcompressor >= 0.4.0: for the core GPTQ quantization logic.
- datasets < 3.0.0: to avoid a mandatory torchcodec dependency, which conflicts with CUDA-enabled PyTorch symbols in certain Linux/WSL environments.
- numpy < 2.3: to maintain compatibility with librosa and numba.
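As a convenience, the installed versions can be checked from inside the environment. This is a minimal sketch using importlib.metadata and is not part of the original setup; the constraint strings are reproduced from the list above.

# Sketch: print installed versions of the pinned packages next to the expected constraints.
from importlib import metadata

constraints = {
    "vllm": ">=0.7.3",
    "llmcompressor": ">=0.4.0",
    "datasets": "<3.0.0",
    "numpy": "<2.3",
}
for package, expected in constraints.items():
    print(f"{package}: installed {metadata.version(package)}, expected {expected}")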
2. Overcoming Engineering Challenges
During the quantization of a multi-modal Whisper model, several critical issues were addressed via advanced Monkey Patching:
- Multimodal Auto-loading Bug: The oneshot API in llm-compressor occasionally misidentifies Whisper as a CausalLM, leading to incomplete model loading. This was solved by manually instantiating WhisperForConditionalGeneration before passing it to the compressor.
- Tracer Missing Decoder Inputs: The underlying torch.fx tracer often strips decoder_input_ids during subgraph analysis. We monkey-patched the model's forward method to inject the standard Traditional Chinese transcription labels (<|startoftranscript|>, <|zh|>, <|transcribe|>, <|notimestamps|>) whenever they are lost during the tracing phase; a minimal sketch of this patch follows the list.
- Data Collator Filtering: Default collators in the framework filter out non-text features (such as input_features). A custom DataCollator and a patched get_calibration_dataloader were used to ensure both audio features and decoder labels are preserved during calibration.
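The sketch below illustrates the idea behind the decoder-input patch. It is a simplified reconstruction rather than the exact patch used during quantization; the helper name patch_whisper_decoder_defaults and the keyword-only argument handling are assumptions.

# Simplified sketch of the forward monkey patch: if decoder_input_ids (and labels)
# are missing, fall back to the standard Traditional Chinese transcription prompt.
import torch

def patch_whisper_decoder_defaults(model, processor):
    original_forward = model.forward
    default_ids = processor.tokenizer.convert_tokens_to_ids(
        ["<|startoftranscript|>", "<|zh|>", "<|transcribe|>", "<|notimestamps|>"]
    )

    def forward_with_defaults(*args, **kwargs):
        features = kwargs.get("input_features")
        if (
            features is not None
            and kwargs.get("decoder_input_ids") is None
            and kwargs.get("labels") is None
        ):
            # Repeat the default prompt for every sample in the batch.
            kwargs["decoder_input_ids"] = torch.tensor(
                [default_ids], device=features.device
            ).repeat(features.shape[0], 1)
        return original_forward(*args, **kwargs)

    model.forward = forward_with_defaults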
3. Execution & Technical Details
The quantization was performed using the oneshot API with the following technical specifics:
- Algorithm: GPTQ (an extension of Optimal Brain Quantization) was used to minimize the mean squared error between the FP16 and 4-bit layer outputs, using the inverse Hessian computed from each layer's calibration activations.
- Layer Targeting: All Linear layers within both the Whisper Encoder and Whisper Decoder were targeted. The lm_head (final projection layer) was explicitly ignored (ignore=["lm_head"]) to prevent performance degradation in the final token selection.
- Weight Packing: The weights are stored in the compressed-tensors format, which vLLM interprets to launch Marlin kernels. Marlin is a high-performance 4-bit quantization kernel designed for NVIDIA GPUs (sm_80+), providing near-theoretical maximum memory bandwidth utilization.
- Audio Preprocessing: Calibration samples were resampled to 16,000 Hz. Features were extracted with the WhisperProcessor to produce an input_features tensor of shape [80, 3000]; a short preprocessing sketch follows this list.
- Calibration Precision: 256 clean audio samples were used. Increasing the sample count beyond this showed diminishing returns in perplexity while significantly increasing quantization time.
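The preprocessing step described above can be illustrated as follows. This is a minimal sketch: the audio file name is a placeholder, and it assumes the base repository ships the standard WhisperProcessor files and a recent transformers version.

# Sketch of the calibration preprocessing: 16 kHz resampling + log-mel feature extraction.
import librosa
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("MediaTek-Research/Breeze-ASR-26")
audio, sampling_rate = librosa.load("calibration_sample.wav", sr=16000)  # placeholder file
features = processor(audio=audio, sampling_rate=16000, return_tensors="pt").input_features
print(features.shape)  # expected: torch.Size([1, 80, 3000])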
Evaluation and Validation
The model was validated using the test_mayday.mp3 sample.
- Inference Latency: The W4A16 version demonstrates a significant reduction in Time-To-First-Token (TTFT) and lower Inter-Token Latency (ITL), i.e. higher decoding throughput, compared to the original FP16 model when deployed on RTX 40-series hardware.
- VRAM Savings: The model footprint is reduced from ~3.1GB to ~0.9GB, allowing deployment on 8GB VRAM consumer-grade GPUs with ample space remaining for KV cache.
Usage with vLLM
This model is specifically optimized for vLLM. It automatically triggers the Marlin acceleration kernel, providing significant speedups over FP16.
from vllm import LLM, SamplingParams
import librosa
# Load audio (standard 16kHz)
audio, sr = librosa.load("your_audio.mp3", sr=16000)
# Initialize engine
llm = LLM(
    model="Breeze-ASR-26-quantized.w4a16",
    max_model_len=448,
    limit_mm_per_prompt={"audio": 1},
)
# Run inference
prompts = {
    "encoder_prompt": {
        "prompt": "",
        "multi_modal_data": {"audio": (audio, sr)},
    },
    "decoder_prompt": "<|startoftranscript|><|zh|><|transcribe|><|notimestamps|>",
}
outputs = llm.generate(prompts, SamplingParams(temperature=0, max_tokens=256))
print(outputs[0].outputs[0].text)
Reproduction Script
The complete logic for reproducing this quantization (including all necessary patches for the llm-compressor Whisper bug) is provided below:
import torch
from llmcompressor.modifiers.quantization import GPTQModifier
from llmcompressor import oneshot
from transformers import AutoProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
import llmcompressor.entrypoints.oneshot
MODEL_ID = "Breeze-ASR-26" # Local directory
# Instantiate Processor and the full Whisper Multimodal Model
processor = AutoProcessor.from_pretrained(MODEL_ID)
print("Loading full Whisper model to avoid oneshot CausalLM auto-loading bug...")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto", low_cpu_mem_usage=True)
# 1. Prepare Calibration Dataset
def preprocess_fn(batch):
    audio_data = batch["audio"]["array"]
    text = batch["text"]
    input_features = processor(
        audio=audio_data,
        sampling_rate=16000,
        return_tensors="pt"
    ).input_features
    decoder_input_ids = processor.tokenizer(
        text,
        return_tensors="pt"
    ).input_ids
    return {
        "input_features": input_features.squeeze(0).tolist(),
        "decoder_input_ids": decoder_input_ids.squeeze(0).tolist()
    }
print("Loading dataset...")
dataset_stream = load_dataset("MLCommons/peoples_speech", "clean", split="train", streaming=True)
dataset_stream = dataset_stream.take(256)
dataset_stream = dataset_stream.map(preprocess_fn)
print("Converting streaming dataset to local dataset...")
data_list = []
for item in dataset_stream:
    data_list.append({
        "input_features": item["input_features"],
        "decoder_input_ids": item["decoder_input_ids"]
    })
my_dataset = Dataset.from_list(data_list)
def custom_data_collator(features):
    return {
        "input_features": torch.tensor([f["input_features"] for f in features]),
        "decoder_input_ids": torch.tensor([f["decoder_input_ids"] for f in features])
    }
# Monkey Patch: Bypass llm-compressor's text-only filtering mechanism
def mock_get_calibration_dataloader(*args, **kwargs):
    return DataLoader(my_dataset, batch_size=1, collate_fn=custom_data_collator)
llmcompressor.entrypoints.oneshot.get_calibration_dataloader = mock_get_calibration_dataloader
# 2. Configure Quantization Recipe
recipe = GPTQModifier(
    targets="Linear",
    scheme="W4A16",
    block_size=128,
    dampening_frac=0.01,
    sequential_targets=["WhisperEncoderLayer", "WhisperDecoderLayer"],
    ignore=["lm_head"]
)
# 3. Run Quantization
print("Starting quantization...")
oneshot(
    model=model,
    dataset=my_dataset,
    recipe=recipe,
    output_dir="Breeze-ASR-26-quantized.w4a16",
    max_seq_length=448,
)
print("Quantization completed successfully!")
Quantized and optimized by SoybeanMilk.