"""Factory for creating STT provider instances."""

import logging
from typing import Dict, Type, Optional

from ..base.stt_provider_base import STTProviderBase
from .whisper_provider import WhisperSTTProvider
from .parakeet_provider import ParakeetSTTProvider
from ...domain.exceptions import SpeechRecognitionException

logger = logging.getLogger(__name__)


class STTProviderFactory:
    """Factory for creating STT provider instances with availability checking and fallback logic.

    Providers are looked up by lower-cased name in ``_providers``. That dict is
    a class attribute, so ``register_provider`` changes the registry for every
    user of this class in the process.
    """

    # Registry: lower-case provider name -> provider class.
    _providers: Dict[str, Type[STTProviderBase]] = {
        "whisper": WhisperSTTProvider,
        "parakeet": ParakeetSTTProvider,
    }

    # Providers tried first (in this order) when falling back; any other
    # registered providers are tried after these, in registration order.
    _fallback_order = ["whisper", "parakeet"]

    @staticmethod
    def _normalize_name(name: object) -> Optional[str]:
        """Return *name* lower-cased, or None when it is not a string."""
        return name.lower() if isinstance(name, str) else None

    @classmethod
    def _fallback_candidates(cls, exclude: Optional[str]) -> list[str]:
        """Return lower-case provider names to try as fallbacks, in priority order.

        Args:
            exclude: Lower-case name to skip (typically the provider that was
                already tried), or None to skip nothing.

        Returns:
            list[str]: ``_fallback_order`` entries first, then every other
            registered provider, without duplicates.
        """
        ordered = [name.lower() for name in cls._fallback_order] + list(cls._providers)
        # dict.fromkeys drops duplicates while keeping first-seen order.
        return [name for name in dict.fromkeys(ordered) if name != exclude]

    @classmethod
    def _create_first_available(cls, exclude: Optional[str]) -> STTProviderBase:
        """Create the first fallback provider that can be built and is available.

        Args:
            exclude: Lower-case provider name to skip, or None.

        Returns:
            STTProviderBase: The first successfully created provider.

        Raises:
            SpeechRecognitionException: If no fallback provider can be created.
        """
        for provider_name in cls._fallback_candidates(exclude):
            try:
                logger.info("Trying fallback STT provider: %s", provider_name)
                return cls.create_provider(provider_name)
            except SpeechRecognitionException as e:
                logger.warning("Fallback STT provider %s failed: %s", provider_name, e)
        raise SpeechRecognitionException("No STT providers are available")

    @classmethod
    def create_provider(cls, provider_name: str) -> STTProviderBase:
        """
        Create an STT provider instance by name.

        Args:
            provider_name: Name of the provider to create (case-insensitive)

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If the name is unknown (or not a string),
                if constructing the provider or checking its availability raises,
                or if the provider reports itself as unavailable
        """
        normalized = cls._normalize_name(provider_name)
        if normalized is None:
            raise SpeechRecognitionException(f"Unknown STT provider: {provider_name!r}")
        provider_name = normalized
        if provider_name not in cls._providers:
            raise SpeechRecognitionException(f"Unknown STT provider: {provider_name}")

        provider_class = cls._providers[provider_name]
        try:
            provider = provider_class()
            available = provider.is_available()
        except Exception as e:
            logger.error("Failed to create STT provider %s: %s", provider_name, e)
            raise SpeechRecognitionException(
                f"Failed to create STT provider {provider_name}: {e}"
            ) from e

        # Checked outside the try block so this error is raised as-is instead of
        # being re-wrapped as a generic creation failure.
        if not available:
            logger.info("STT provider %s is not available", provider_name)
            raise SpeechRecognitionException(f"STT provider {provider_name} is not available")

        logger.info("Created STT provider: %s", provider_name)
        return provider

    @classmethod
    def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
        """
        Create an STT provider with fallback to other available providers.

        The preferred provider is tried first. Then ``_fallback_order`` is
        tried, followed by any other registered providers. The preferred
        provider is never retried.

        Args:
            preferred_provider: The preferred provider name

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If no providers are available
        """
        try:
            return cls.create_provider(preferred_provider)
        except SpeechRecognitionException as e:
            logger.warning("Preferred STT provider %s failed: %s", preferred_provider, e)

        return cls._create_first_available(exclude=cls._normalize_name(preferred_provider))

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available STT providers.

        Each registered provider is instantiated and asked ``is_available()``;
        any exception during that probe counts as "not available".

        Returns:
            list[str]: List of available provider names, in registry order
        """
        available = []
        for provider_name, provider_class in cls._providers.items():
            try:
                if provider_class().is_available():
                    available.append(provider_name)
            except Exception as e:
                logger.info("Provider %s not available: %s", provider_name, e)
        return available

    @classmethod
    def get_provider_info(cls, provider_name: str) -> Optional[dict]:
        """
        Get information about a specific provider.

        Args:
            provider_name: Name of the provider (case-insensitive)

        Returns:
            Optional[dict]: None if the provider is not registered (or the name
            is not a string). Otherwise a dict with ``name``, ``available``,
            ``supported_languages``, ``available_models`` and ``default_model``.
            If querying the provider raises, the dict instead holds ``name``,
            ``available=False`` and ``error``.
        """
        normalized = cls._normalize_name(provider_name)
        if normalized is None or normalized not in cls._providers:
            return None
        provider_name = normalized

        provider_class = cls._providers[provider_name]
        try:
            provider = provider_class()
            # Query availability once so every field below reflects the same answer.
            available = provider.is_available()
            return {
                "name": provider.provider_name,
                "available": available,
                "supported_languages": provider.supported_languages,
                "available_models": provider.get_available_models() if available else [],
                "default_model": provider.get_default_model() if available else None,
            }
        except Exception as e:
            logger.info("Failed to get info for provider %s: %s", provider_name, e)
            return {
                "name": provider_name,
                "available": False,
                "error": str(e),
            }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
        """
        Register a new STT provider.

        The name is stored lower-cased; registering an existing name replaces
        the previous class. Registered providers also become fallback candidates
        after ``_fallback_order``.

        Args:
            name: Name of the provider
            provider_class: The provider class
        """
        cls._providers[name.lower()] = provider_class
        logger.info("Registered STT provider: %s", name)


# Legacy compatibility - create an ASRFactory alias
class ASRFactory:
    """Legacy ASRFactory for backward compatibility."""

    @staticmethod
    def get_model(model_name: str = "parakeet") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

        Args:
            model_name: Name of the model/provider to use. Legacy aliases are
                mapped to provider names; other names are passed through
                lower-cased.

        Returns:
            STTProviderBase: The provider instance (the requested one, or the
            first available fallback)

        Raises:
            SpeechRecognitionException: If no providers are available
        """
        # Map legacy model names to provider names
        provider_mapping = {
            "whisper": "whisper",
            "parakeet": "parakeet",
            "faster-whisper": "whisper",
        }
        requested = model_name.lower()
        provider_name = provider_mapping.get(requested, requested)

        try:
            return STTProviderFactory.create_provider(provider_name)
        except SpeechRecognitionException:
            logger.warning("Requested provider %s not available, using fallback", provider_name)
            # The requested provider has just failed, so go straight to the
            # fallbacks rather than constructing and probing it a second time.
            return STTProviderFactory._create_first_available(exclude=provider_name)