from collections.abc import Iterator
from pathlib import Path
from typing import Any

import torch
import transformers

try:
    from .asr_modeling import ASRModel
except ImportError:
    from asr_modeling import ASRModel


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR pipeline for audio-to-text transcription.

    Extends the stock HF pipeline with per-task generation presets
    ("transcribe", "emotion", "describe", "continue") and a text-only
    mode (``task="text"``), all delegated to the custom ``ASRModel``.
    """

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        # Fall back to the model's own processor when no feature extractor is given.
        if feature_extractor is None:
            processor = model.get_processor()
            feature_extractor = processor.feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )

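    # A minimal construction sketch (hedged: `ASRModel.from_pretrained` is an
    # assumed constructor on the custom model class, mirroring the usual HF API):
    #
    #   model = ASRModel.from_pretrained("path/to/checkpoint")
    #   pipe = ASRPipeline(model=model)
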
    def __call__(self, inputs, **kwargs):
        # Split generation-related kwargs from preprocessing kwargs.
        generate_kwargs = {}
        generate_keys = [
            "max_new_tokens",
            "num_beams",
            "do_sample",
            "length_penalty",
            "repetition_penalty",
            "no_repeat_ngram_size",
            "early_stopping",
            "num_beam_groups",
            "diversity_penalty",
            "top_k",
            "temperature",
            "top_p",
            "user_prompt",
            "task",
            "text_input",
        ]
        for key in generate_keys:
            if key in kwargs:
                generate_kwargs[key] = kwargs.pop(key)

        # Text-only generation bypasses audio preprocessing entirely.
        task = generate_kwargs.get("task")
        if task == "text" or generate_kwargs.get("text_input"):
            return self._process_text_only(generate_kwargs)

        # Recurse over batches; each item re-splits its own kwargs.
        if isinstance(inputs, list):
            return [self.__call__(inp, **kwargs, **generate_kwargs) for inp in inputs]

        model_inputs = self.preprocess(inputs, **kwargs)

        # preprocess() yields lazily when the parent pipeline chunks the audio.
        if isinstance(model_inputs, Iterator):
            return self._process_chunks(list(model_inputs), generate_kwargs)

        model_outputs = self._forward(model_inputs, **generate_kwargs)
        return self.postprocess(model_outputs)

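    # Illustrative call patterns (file-path loading comes from the parent
    # pipeline; only kwargs listed in `generate_keys` reach generation):
    #
    #   pipe("sample.wav", task="transcribe", max_new_tokens=128)
    #   pipe([audio_a, audio_b], num_beams=4)          # batched, returns a list
    #   pipe(None, task="text", text_input="Hello")    # audio-free generation
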
    def _process_chunks(self, chunks: list, generate_kwargs: dict) -> dict[str, str]:
        """Process chunked audio and merge results."""
        all_tokens: list[int] = []

        for chunk in chunks:
            output = self._forward(chunk, **generate_kwargs)
            tokens = output.get("tokens")
            if tokens is None:
                tokens = output.get("generated_ids")
            if tokens is not None:
                if torch.is_tensor(tokens):
                    tokens = tokens.cpu()
                    # Batched output: keep only the first sequence.
                    if len(tokens.shape) > 1:
                        tokens = tokens[0]
                all_tokens.extend(tokens.tolist() if torch.is_tensor(tokens) else tokens)

        text = self.tokenizer.decode(all_tokens, skip_special_tokens=True).strip()

        return {"text": text}

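    # Chunking sketch: `chunk_length_s` is forwarded to the parent pipeline's
    # preprocess, which then yields chunks that land in _process_chunks:
    #
    #   pipe("long_recording.wav", chunk_length_s=30, task="transcribe")
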
    def preprocess(self, inputs, **preprocess_params):
        if isinstance(inputs, list):
            raise ValueError("Lists should not reach preprocess")

        # Disable the parent pipeline's chunking unless explicitly requested.
        preprocess_params.setdefault("chunk_length_s", 0)

        # Normalize the accepted input formats to {"raw", "sampling_rate"}.
        if isinstance(inputs, dict):
            if "bytes" in inputs:
                inputs = self._decode_audio_bytes(inputs["bytes"])
            elif "array" in inputs:
                inputs = {
                    "raw": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
                }
            elif "path" in inputs:
                inputs = self._decode_audio_bytes(Path(inputs["path"]).read_bytes())
        elif hasattr(inputs, "array") and hasattr(inputs, "sampling_rate"):
            inputs = {"raw": inputs.array, "sampling_rate": inputs.sampling_rate}
        elif torch.is_tensor(inputs):
            # Check tensors before the generic __array__ branch: torch tensors
            # also implement __array__, but must be converted to numpy here.
            inputs = {
                "raw": inputs.cpu().numpy(),
                "sampling_rate": self.model.config.audio_sample_rate,
            }
        elif hasattr(inputs, "__array__") and not isinstance(inputs, (dict, bytes, str)):
            inputs = {"raw": inputs, "sampling_rate": self.model.config.audio_sample_rate}

        # Resample on the fly if the input rate differs from the extractor's.
        if isinstance(inputs, dict) and "sampling_rate" in inputs:
            in_sr = inputs["sampling_rate"]
            target_sr = self.feature_extractor.sampling_rate
            if in_sr != target_sr:
                import librosa
                import numpy as np

                audio = inputs["raw"]
                if hasattr(audio, "numpy"):
                    audio = audio.numpy()
                resampled = librosa.resample(
                    np.asarray(audio, dtype=np.float32), orig_sr=in_sr, target_sr=target_sr
                )
                inputs = {"raw": resampled, "sampling_rate": target_sr}

        return super().preprocess(inputs, **preprocess_params)

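    # Input shapes normalized above (illustrative):
    #
    #   pipe({"bytes": wav_bytes})                        # encoded audio bytes
    #   pipe({"array": samples, "sampling_rate": 16000})  # raw samples
    #   pipe({"path": "clip.wav"})                        # file on disk
    #   pipe(np_samples)                                  # numpy array or torch tensor
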
    def _decode_audio_bytes(self, wav_bytes: bytes) -> dict[str, Any]:
        """Decode audio bytes to array format."""
        import io

        import soundfile as sf

        audio_data, sample_rate = sf.read(io.BytesIO(wav_bytes))
        return {
            "raw": audio_data,
            "sampling_rate": sample_rate,
        }

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        task: str | None = generate_kwargs.pop("task", None)

        # Per-task sampling presets; explicit kwargs always win via setdefault.
        task_params: dict[str, dict[str, Any]] = {
            "transcribe": {"do_sample": False},
            "emotion": {"do_sample": True, "temperature": 0.7},
            "describe": {"do_sample": True, "temperature": 0.7},
            "continue": {"do_sample": True, "temperature": 1.0},
        }
        if task in task_params:
            for key, value in task_params[task].items():
                generate_kwargs.setdefault(key, value)

        audio_inputs, is_whisper = self._extract_audio(model_inputs)
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        generate_kwargs.setdefault(
            "eos_token_id", self.model.tokenizer.convert_tokens_to_ids("<|im_end|>")
        )
        generate_kwargs.setdefault("max_new_tokens", self.model.config.max_new_tokens)

        # Whisper-style features go in as `input_features`; raw waveforms as `input_values`.
        if is_whisper:
            generated_ids = self.model.generate(
                input_features=audio_inputs,
                task=task,
                **generate_kwargs,
            )
        else:
            generated_ids = self.model.generate(
                input_values=audio_inputs,
                task=task,
                **generate_kwargs,
            )

        return {"tokens": generated_ids, "is_last": is_last}

    def _extract_audio(self, model_inputs) -> tuple[torch.Tensor, bool]:
        """Extract the audio tensor from various input formats.

        Returns the tensor and whether it is Whisper-style ``input_features``.
        """
        if isinstance(model_inputs, torch.Tensor):
            return model_inputs.to(self.model.device), False

        if isinstance(model_inputs, (list, tuple)) and model_inputs:
            model_inputs = (
                model_inputs[0]
                if isinstance(model_inputs[0], dict)
                else {"input_values": model_inputs[0]}
            )

        if isinstance(model_inputs, dict):
            model_inputs.pop("stride", None)  # chunking metadata, not audio
            if "input_features" in model_inputs:
                return model_inputs["input_features"].to(self.model.device), True
            if "input_values" in model_inputs:
                return model_inputs["input_values"].to(self.model.device), False

        raise ValueError(f"Could not extract audio from {type(model_inputs)}")

    def _process_text_only(self, generate_kwargs: dict) -> dict[str, str]:
        """Process text-only input without audio."""
        text_input = generate_kwargs.pop("text_input", None)
        if text_input is None:
            raise ValueError("text_input required for text task")

        generate_kwargs.pop("task", None)
        generated_ids = self.model.generate(task="text", text_input=text_input, **generate_kwargs)
        text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

        return {"text": text}

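    # Text-only sketch: no audio is touched, and `task` is pinned to "text":
    #
    #   pipe(None, task="text", text_input="Summarize the meeting notes.")
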
    def postprocess(
        self, model_outputs: dict[str, Any], return_timestamps=None, return_language=None
    ) -> dict[str, str]:
        if isinstance(model_outputs, list):
            # Move tensors to CPU before handing off to the parent implementation.
            for output in model_outputs:
                for key, value in output.items():
                    if torch.is_tensor(value):
                        output[key] = value.cpu()
            return super().postprocess(model_outputs)

        model_outputs.pop("is_last", None)
        # Avoid `tokens or ...` here: truthiness of a multi-element tensor
        # raises a RuntimeError.
        tokens = model_outputs.get("tokens")
        if tokens is None:
            tokens = model_outputs.get("generated_ids")

        if tokens is None:
            raise ValueError(f"Expected 'tokens' or 'generated_ids', got: {model_outputs.keys()}")

        if torch.is_tensor(tokens):
            if tokens.device.type != "cpu":
                tokens = tokens.cpu()
            if len(tokens.shape) > 1:
                tokens = tokens[0]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()

        return {"text": text}
|
|
|
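if __name__ == "__main__":
    # Smoke-test sketch, not part of the pipeline proper. Hedged assumptions:
    # `ASRModel.from_pretrained` is a hypothetical constructor on the custom
    # model class, and the checkpoint/audio paths come from the command line.
    import sys

    model = ASRModel.from_pretrained(sys.argv[1])  # assumed constructor
    pipe = ASRPipeline(model=model)
    result = pipe(sys.argv[2], task="transcribe")
    print(result["text"])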