"""Whisper STT provider implementation."""

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from ...domain.models.audio_content import AudioContent
    from ...domain.models.text_content import TextContent

from ..base.stt_provider_base import STTProviderBase
from ...domain.exceptions import SpeechRecognitionException

logger = logging.getLogger(__name__)


class WhisperSTTProvider(STTProviderBase):
    """Whisper STT provider using faster-whisper implementation."""

    def __init__(self):
        """Initialize the Whisper STT provider.

        Registers the provider with the base class, then selects the device
        and compute type. The Whisper model itself is loaded lazily on the
        first transcription request.
        """
        super().__init__(
            provider_name="Whisper",
            supported_languages=["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
        )
        self.model = None
        # Name of the model currently held in ``self.model``. The loaded
        # faster-whisper model object does not expose the name it was created
        # with, so it is tracked here to decide whether a reload is needed.
        self._loaded_model_name: Optional[str] = None
        self._device = None
        self._compute_type = None
        self._initialize_device_settings()

    def _initialize_device_settings(self):
        """Initialize device and compute type settings.

        Uses CUDA with float16 when torch is importable and reports a GPU;
        otherwise falls back to CPU with int8.
        """
        try:
            import torch

            self._device = "cuda" if torch.cuda.is_available() else "cpu"
        except ImportError:
            # Fallback to CPU if torch is not available
            self._device = "cpu"

        self._compute_type = "float16" if self._device == "cuda" else "int8"
        logger.info(
            "Whisper provider initialized with device: %s, compute_type: %s",
            self._device,
            self._compute_type,
        )

    def _perform_transcription(
        self, audio_path: Path, model: str, language: Optional[str] = "en"
    ) -> str:
        """
        Perform transcription using Faster Whisper.

        Args:
            audio_path: Path to the preprocessed audio file
            model: The Whisper model to use (e.g., 'large-v3', 'medium', 'small')
            language: Language code passed to faster-whisper. Defaults to "en"
                (the previous hard-coded value); pass None to let
                faster-whisper detect the language.

        Returns:
            str: The transcribed text, with segments joined by single spaces

        Errors are forwarded to ``self._handle_provider_error`` (defined in the
        base class). NOTE(review): presumably that helper raises; if it ever
        returns, this method returns None.
        """
        try:
            # Only (re)load when nothing is loaded yet or a different model is requested.
            if self.model is None or self._loaded_model_name != model:
                self._load_model(model)

            logger.info("Starting Whisper transcription with model %s", model)

            # Perform transcription
            segments, info = self.model.transcribe(
                str(audio_path),
                beam_size=5,
                language=language,
                task="transcribe",
            )
            logger.info(
                "Detected language '%s' with probability %s",
                info.language,
                info.language_probability,
            )

            # Segment text usually carries its own leading whitespace; strip
            # each piece and join once so the result has single spaces.
            pieces = []
            for segment in segments:
                # DEBUG rather than INFO: this is per-segment and logs transcript content.
                logger.debug("[%.2fs -> %.2fs] %s", segment.start, segment.end, segment.text)
                text = segment.text.strip()
                if text:
                    pieces.append(text)

            result = " ".join(pieces)
            logger.info("Whisper transcription completed successfully")
            return result

        except Exception as e:
            self._handle_provider_error(e, "transcription")

    def _load_model(self, model_name: str):
        """
        Load the Whisper model.

        Args:
            model_name: Name of the model to load

        Raises:
            SpeechRecognitionException: if faster-whisper cannot be imported
                or the model cannot be constructed.
        """
        try:
            from faster_whisper import WhisperModel as FasterWhisperModel
        except ImportError as e:
            raise SpeechRecognitionException(
                "faster-whisper not available. Please install with: pip install faster-whisper"
            ) from e

        logger.info("Loading Whisper model: %s", model_name)
        logger.info("Using device: %s, compute_type: %s", self._device, self._compute_type)

        # Drop the previous model first so its memory can be reclaimed before
        # the new one is allocated (matters when switching large GPU models).
        self.model = None
        self._loaded_model_name = None

        try:
            self.model = FasterWhisperModel(
                model_name,
                device=self._device,
                compute_type=self._compute_type,
            )
        except Exception as e:
            raise SpeechRecognitionException(
                f"Failed to load Whisper model {model_name}: {str(e)}"
            ) from e

        # Record the name only after a successful load.
        self._loaded_model_name = model_name
        logger.info("Whisper model %s loaded successfully", model_name)

    def is_available(self) -> bool:
        """
        Check if the Whisper provider is available.

        Returns:
            bool: True if faster-whisper is available, False otherwise
        """
        try:
            import faster_whisper  # noqa: F401  -- imported only to test availability
        except ImportError:
            logger.warning("faster-whisper not available")
            return False
        return True

    def get_available_models(self) -> list[str]:
        """
        Get list of available Whisper models.

        Returns:
            list[str]: List of available model names (a fresh list each call)
        """
        return [
            "tiny",
            "tiny.en",
            "base",
            "base.en",
            "small",
            "small.en",
            "medium",
            "medium.en",
            "large-v1",
            "large-v2",
            "large-v3",
        ]

    def get_default_model(self) -> str:
        """
        Get the default model for this provider.

        Returns:
            str: Default model name
        """
        return "large-v3"