File size: 9,531 Bytes
d7aa21c 10b2bb7 d7aa21c 10b2bb7 d7aa21c 10b2bb7 d7aa21c 01ab6d7 d7aa21c 01ab6d7 d7aa21c 01ab6d7 d7aa21c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 |
from collections.abc import Iterator
from pathlib import Path
from typing import Any
import torch
import transformers
try:
from .asr_modeling import ASRModel
except ImportError:
from asr_modeling import ASRModel # type: ignore[no-redef]
class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR pipeline for audio-to-text transcription built on ``ASRModel``.

    Compared with the stock transformers ASR pipeline, this subclass:

    * lets generation options (see ``_GENERATE_KEYS``) be passed straight to
      ``__call__``;
    * supports a text-only mode (``task="text"`` or a truthy ``text_input``)
      that skips audio entirely;
    * normalizes several extra audio input formats and resamples them to the
      feature extractor's rate before calling ``super().preprocess``.
    """

    model: ASRModel

    # Keyword arguments that __call__ routes to generation instead of to
    # preprocess().
    _GENERATE_KEYS: tuple[str, ...] = (
        "max_new_tokens",
        "num_beams",
        "do_sample",
        "length_penalty",
        "repetition_penalty",
        "no_repeat_ngram_size",
        "early_stopping",
        "num_beam_groups",
        "diversity_penalty",
        "top_k",
        "temperature",
        "top_p",
        "user_prompt",
        "task",
        "text_input",
    )

    # Per-task generation defaults. _forward applies them with setdefault, so
    # explicitly passed generation kwargs always win. Never mutated.
    _TASK_DEFAULTS: dict[str, dict[str, Any]] = {
        "transcribe": {"do_sample": False},
        "emotion": {"do_sample": True, "temperature": 0.7},
        "describe": {"do_sample": True, "temperature": 0.7},
        "continue": {"do_sample": True, "temperature": 1.0},
    }

    def __init__(self, model: ASRModel, **kwargs):
        """Build the pipeline, defaulting tokenizer and feature extractor from *model*.

        Args:
            model: The ASR model. ``model.tokenizer`` is used when no tokenizer
                is supplied (or ``None`` is passed), and
                ``model.get_processor().feature_extractor`` when no feature
                extractor is supplied (or ``None`` is passed).
            **kwargs: Forwarded to ``AutomaticSpeechRecognitionPipeline``.
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", None)
        # An explicit None is treated like "not provided" for both components.
        if tokenizer is None:
            tokenizer = model.tokenizer
        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor
        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )

    def __call__(self, inputs, **kwargs):
        """Transcribe *inputs* (or run text-only generation).

        Args:
            inputs: One audio input in any format ``preprocess`` accepts, or a
                list of them. Lists are processed one element at a time and a
                list of results is returned. Ignored in text-only mode.
            **kwargs: Keys in ``_GENERATE_KEYS`` are routed to generation;
                everything else is forwarded to ``preprocess``.

        Returns:
            ``{"text": ...}`` (or a list of those for list inputs).
        """
        generate_kwargs: dict[str, Any] = {}
        for key in self._GENERATE_KEYS:
            if key in kwargs:
                generate_kwargs[key] = kwargs.pop(key)

        # Text-only mode needs no audio at all.
        if generate_kwargs.get("task") == "text" or generate_kwargs.get("text_input"):
            return self._process_text_only(generate_kwargs)

        if isinstance(inputs, list):
            return [self.__call__(inp, **kwargs, **generate_kwargs) for inp in inputs]

        model_inputs = self.preprocess(inputs, **kwargs)
        if isinstance(model_inputs, Iterator):
            # Consume lazily: no need to hold every chunk's features at once.
            return self._process_chunks(model_inputs, generate_kwargs)
        model_outputs = self._forward(model_inputs, **generate_kwargs)
        return self.postprocess(model_outputs)

    @staticmethod
    def _first_sequence(outputs: dict[str, Any]) -> "list[int] | None":
        """Return the first generated sequence in *outputs* as a list of token ids.

        Reads ``"tokens"`` and falls back to ``"generated_ids"``. Both lookups are
        compared to ``None`` explicitly. Truth-testing a multi-element tensor
        raises ``RuntimeError``, and a lone token id 0 would read as "missing".

        Returns:
            ``None`` if neither key holds a value. Otherwise the ids as a Python
            list, taking row 0 when the value is nested (e.g. a 2-D tensor).
        """
        tokens = outputs.get("tokens")
        if tokens is None:
            tokens = outputs.get("generated_ids")
        if tokens is None:
            return None
        if hasattr(tokens, "tolist"):
            # torch tensors / numpy arrays -> (possibly nested) Python lists.
            tokens = tokens.tolist()
        if not isinstance(tokens, (list, tuple)):
            # A 0-d tensor/array turns into a bare scalar id.
            return [tokens]
        if tokens and isinstance(tokens[0], (list, tuple)):
            tokens = tokens[0]
        return list(tokens)

    def _process_chunks(
        self, chunks: "Iterator[dict[str, Any]] | list[dict[str, Any]]", generate_kwargs: dict
    ) -> dict[str, str]:
        """Run generation on each chunk and decode the concatenated token ids.

        Args:
            chunks: Preprocessed chunks (as yielded by ``preprocess``), consumed
                in order.
            generate_kwargs: Generation options. They are unpacked per call, so
                ``_forward`` mutations do not leak between chunks.

        Returns:
            ``{"text": ...}``, decoded with special tokens skipped and surrounding
            whitespace stripped.
        """
        all_tokens: list[int] = []
        for chunk in chunks:
            output = self._forward(chunk, **generate_kwargs)
            tokens = self._first_sequence(output)
            if tokens is not None:
                all_tokens.extend(tokens)
        text = self.tokenizer.decode(all_tokens, skip_special_tokens=True).strip()
        return {"text": text}

    def preprocess(self, inputs, **preprocess_params):
        """Normalize and resample *inputs*, then delegate to the base class.

        Args:
            inputs: A single (non-list) audio input; see ``_normalize_audio_input``
                for the formats handled here. Anything else is passed through.
            **preprocess_params: Forwarded to ``super().preprocess``.
                ``chunk_length_s`` defaults to 0.

        Raises:
            ValueError: if *inputs* is a list (``__call__`` splits lists first).
        """
        if isinstance(inputs, list):
            raise ValueError("Lists should not reach preprocess")
        preprocess_params.setdefault("chunk_length_s", 0)
        inputs = self._normalize_audio_input(inputs)
        inputs = self._resample_if_needed(inputs)
        return super().preprocess(inputs, **preprocess_params)

    def _normalize_audio_input(self, inputs):
        """Convert supported formats to a ``{"raw": ..., "sampling_rate": ...}`` dict.

        Handled:
            * dict with non-None ``"bytes"``: decoded via ``_decode_audio_bytes``;
            * dict with ``"array"``: ``sampling_rate`` defaults to the feature
              extractor's rate;
            * dict with a non-None ``"path"`` and no usable bytes/array: the file
              is read and decoded;
            * torch tensor: moved to CPU numpy, rate = ``config.audio_sample_rate``;
            * objects exposing ``.array`` and ``.sampling_rate``;
            * other ``__array__`` objects (not str/bytes): rate =
              ``config.audio_sample_rate``.

        Anything else (e.g. ``str`` paths, raw ``bytes``) is returned unchanged
        for the base class to handle.
        """
        if isinstance(inputs, dict):
            audio_bytes = inputs.get("bytes")
            if audio_bytes is not None:
                return self._decode_audio_bytes(audio_bytes)
            if "array" in inputs:
                return {
                    "raw": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
                }
            if inputs.get("path") is not None:
                # Lazy-loaded audio (no bytes, no array): load from path.
                return self._decode_audio_bytes(Path(inputs["path"]).read_bytes())
            return inputs
        # Tensors must be checked before the generic __array__ branch: they also
        # define __array__, and may live on a non-CPU device.
        if torch.is_tensor(inputs):
            return {
                "raw": inputs.detach().cpu().numpy(),
                "sampling_rate": self.model.config.audio_sample_rate,
            }
        if hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
            return {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
        if hasattr(inputs, "__array__") and not isinstance(inputs, (bytes, str)):
            return {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}
        return inputs

    def _resample_if_needed(self, inputs):
        """Resample ``inputs["raw"]`` to the feature extractor's rate when they differ.

        Workaround for a transformers resampling bug, see
        https://github.com/huggingface/transformers/pull/41298.
        Non-dict inputs, dicts without ``"sampling_rate"``, and dicts already at
        the target rate are returned unchanged. A resampled result contains only
        ``"raw"`` (float32) and ``"sampling_rate"``.
        """
        if not (isinstance(inputs, dict) and "sampling_rate" in inputs):
            return inputs
        in_sr = inputs["sampling_rate"]
        target_sr = self.feature_extractor.sampling_rate
        if in_sr == target_sr:
            return inputs

        import librosa
        import numpy as np

        audio = inputs["raw"]
        if torch.is_tensor(audio):
            # .numpy() alone fails for tensors that are not on the CPU.
            audio = audio.detach().cpu().numpy()
        elif hasattr(audio, "numpy"):
            audio = audio.numpy()
        resampled = librosa.resample(
            np.asarray(audio, dtype=np.float32), orig_sr=in_sr, target_sr=target_sr
        )
        return {"raw": resampled, "sampling_rate": target_sr}

    def _decode_audio_bytes(self, wav_bytes: bytes) -> dict[str, Any]:
        """Decode an in-memory audio file with soundfile.

        Multi-channel audio (soundfile returns a frames x channels array) is
        averaged down to mono, because the base ASR pipeline only accepts
        single-channel input.

        Returns:
            ``{"raw": samples, "sampling_rate": rate}``.
        """
        import io

        import soundfile as sf

        audio_data, sample_rate = sf.read(io.BytesIO(wav_bytes))
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        return {
            "raw": audio_data,
            "sampling_rate": sample_rate,
        }

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Generate token ids for one preprocessed input.

        Args:
            model_inputs: Anything ``_extract_audio`` accepts. If it is a dict,
                its ``"is_last"`` and ``"stride"`` keys are popped in place.
            **generate_kwargs: Generation options. ``task`` is removed and passed
                to ``model.generate`` separately; it also selects defaults from
                ``_TASK_DEFAULTS``. ``eos_token_id`` defaults to the id of
                ``<|im_end|>`` and ``max_new_tokens`` to
                ``model.config.max_new_tokens``.

        Returns:
            ``{"tokens": <model.generate output>, "is_last": bool}``.
        """
        task: str | None = generate_kwargs.pop("task", None)
        for key, value in self._TASK_DEFAULTS.get(task, {}).items():
            generate_kwargs.setdefault(key, value)

        audio_inputs, is_whisper = self._extract_audio(model_inputs)
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        generate_kwargs.setdefault(
            "eos_token_id", self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
        )
        generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)

        # Hand the audio back to generate() under the same key _extract_audio
        # found it in.
        audio_key = "input_features" if is_whisper else "input_values"
        audio_kwarg = {audio_key: audio_inputs}
        generated_ids = self.model.generate(**audio_kwarg, task=task, **generate_kwargs)
        return {"tokens": generated_ids, "is_last": is_last}

    def _extract_audio(self, model_inputs) -> tuple[torch.Tensor, bool]:
        """Pull the audio tensor out of *model_inputs* and move it to the model's device.

        Accepts:
            * a bare tensor (treated as ``input_values``);
            * a non-empty list/tuple, of which only the first element is used.
              A non-dict element is treated as ``input_values``.
            * a dict with ``"input_features"`` or ``"input_values"``. Its
              ``"stride"`` key is removed in place.

        Returns:
            ``(audio, is_whisper)``. ``is_whisper`` is True when the audio came
            from an ``"input_features"`` key.

        Raises:
            ValueError: if no audio can be found.
        """
        if isinstance(model_inputs, torch.Tensor):
            return model_inputs.to(self.model.device), False
        if isinstance(model_inputs, (list, tuple)) and model_inputs:
            first = model_inputs[0]
            model_inputs = first if isinstance(first, dict) else {"input_values": first}
        if isinstance(model_inputs, dict):
            model_inputs.pop("stride", None)
            if "input_features" in model_inputs:
                return model_inputs["input_features"].to(self.model.device), True
            if "input_values" in model_inputs:
                return model_inputs["input_values"].to(self.model.device), False
        raise ValueError(f"Could not extract audio from {type(model_inputs)}")

    def _process_text_only(self, generate_kwargs: dict) -> dict[str, str]:
        """Generate from ``text_input`` alone, without any audio.

        Args:
            generate_kwargs: Must contain ``"text_input"``. ``"task"`` is
                discarded (``task="text"`` is always used). The remaining keys
                are passed to ``model.generate``. The dict is mutated.

        Returns:
            ``{"text": ...}`` decoded from the first generated sequence with
            special tokens skipped (not stripped).

        Raises:
            ValueError: if ``text_input`` is missing or None.
        """
        text_input = generate_kwargs.pop("text_input", None)
        if text_input is None:
            raise ValueError("text_input required for text task")
        generate_kwargs.pop("task", None)
        generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        return {"text": text}

    def postprocess(
        self, model_outputs: dict[str, Any], return_timestamps=None, return_language=None
    ) -> dict[str, str]:
        """Decode generation output into ``{"text": ...}``.

        A list of outputs (the base pipeline's chunked path) has its tensors
        moved to the CPU and is handed to ``super().postprocess``. A single dict
        is decoded directly from its first generated sequence, with special
        tokens skipped and whitespace stripped.

        ``return_timestamps`` and ``return_language`` are accepted for
        signature compatibility but currently unused.

        Raises:
            ValueError: if a dict output has neither ``"tokens"`` nor
                ``"generated_ids"``.
        """
        if isinstance(model_outputs, list):
            for output in model_outputs:
                for key, value in output.items():
                    if torch.is_tensor(value):
                        output[key] = value.cpu()
            return super().postprocess(model_outputs)

        model_outputs.pop("is_last", None)
        tokens = self._first_sequence(model_outputs)
        if tokens is None:
            raise ValueError(f"Expected 'tokens' or 'generated_ids', got: {model_outputs.keys()}")
        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        return {"text": text}
|