```python
# stt.py
import os
from typing import Tuple

import numpy as np
import spaces  # Hugging Face Spaces SDK (provides the @spaces.GPU decorator on ZeroGPU hardware)
import torch
import torchaudio
from numpy.typing import NDArray
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Create the output directory for saved transcriptions
os.makedirs("transcriptions", exist_ok=True)

# Lazily initialized globals, shared across all instances
whisper_model = None
whisper_processor = None

# Map short size names to Hugging Face model identifiers
WHISPER_MODEL_SIZES = {
    'tiny': 'openai/whisper-tiny',
    'base': 'openai/whisper-base',
    'small': 'openai/whisper-small',
    'medium': 'openai/whisper-medium',
    'large': 'openai/whisper-large-v3',
}
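# Approximate checkpoint sizes, per the Whisper paper: tiny ~39M, base ~74M,
# small ~244M, medium ~769M, large(-v3) ~1.55B parameters.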

class WhisperSTTModel:
    def __init__(self, model_size="base", language="en"):
        self.model_size = model_size
        self.language = language
        self._initialize_model()

    def _initialize_model(self):
        global whisper_model, whisper_processor
        # Resolve the model identifier, falling back to 'base' for unknown sizes
        model_id = WHISPER_MODEL_SIZES.get(self.model_size.lower(), WHISPER_MODEL_SIZES['base'])
        # Load the model and processor if absent or if a different checkpoint is requested
        if (whisper_model is None or whisper_processor is None
                or whisper_model.config._name_or_path != model_id):
            print(f"Loading Whisper {self.model_size} model...")
            whisper_processor = WhisperProcessor.from_pretrained(model_id)
            whisper_model = WhisperForConditionalGeneration.from_pretrained(model_id)
            print(f"Model loaded on device: {whisper_model.device}")
    def stt(self, audio: Tuple[int, NDArray[np.float32]]) -> str:
        """Transcribe audio to text following the STTModel protocol."""
        sample_rate, audio_array = audio
        try:
            # Whisper expects float32 PCM; cast defensively in case a caller passes another dtype
            audio_array = audio_array.astype(np.float32)
            # Downmix (channels, samples) audio to mono
            if audio_array.ndim > 1 and audio_array.shape[0] > 1:
                audio_array = np.mean(audio_array, axis=0)
            # Convert the numpy array to a torch tensor with a batch dimension
            speech_array = torch.tensor(audio_array).unsqueeze(0)
            # Whisper operates on 16 kHz audio; resample if needed
            if sample_rate != 16000:
                resampler = torchaudio.transforms.Resample(sample_rate, 16000)
                speech_array = resampler(speech_array)
            # Compute log-mel input features, placed on the same device as the model
            input_features = whisper_processor(
                speech_array.squeeze().numpy(),
                sampling_rate=16000,
                return_tensors="pt",
            ).input_features.to(whisper_model.device)
            # Force the decoder to transcribe in the configured language
            # (forced_decoder_ids is deprecated in newer transformers releases
            # in favor of generate(language=..., task=...))
            generation_kwargs = {}
            if self.language:
                forced_decoder_ids = whisper_processor.get_decoder_prompt_ids(
                    language=self.language, task="transcribe"
                )
                generation_kwargs["forced_decoder_ids"] = forced_decoder_ids
            # Run generation without tracking gradients
            with torch.no_grad():
                predicted_ids = whisper_model.generate(input_features, **generation_kwargs)
            # Decode token ids to text and return the first (only) sequence
            transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
            return transcription[0]
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return ""

# Singleton instance for easy import
whisper_stt = WhisperSTTModel(model_size="base", language="en")


# Legacy function for backward compatibility
async def transcribe_audio(audio_file_path, model_size="base", language="en"):
    """Transcribe an audio file on disk (kept for older callers).

    model_size and language are accepted for signature compatibility only;
    the shared whisper_stt instance is used as-is.
    """
    # torchaudio.load returns a (channels, samples) float tensor
    speech_array, sample_rate = torchaudio.load(audio_file_path)
    return whisper_stt.stt((sample_rate, speech_array.squeeze().numpy()))
```
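For reference, a minimal usage sketch (not part of the original file; `example_usage.py` is a hypothetical name). It assumes the file above is saved as `stt.py` and that `torch`, `torchaudio`, `transformers`, and the `spaces` package are installed; importing the module downloads the base checkpoint on first run.

```python
# example_usage.py -- hypothetical driver, not part of the Space
import numpy as np

from stt import whisper_stt

# One second of a 440 Hz sine wave at 22.05 kHz; stt() resamples to 16 kHz
sample_rate = 22050
t = np.linspace(0, 1.0, sample_rate, endpoint=False)
audio = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

# The STTModel protocol takes a (sample_rate, samples) tuple
text = whisper_stt.stt((sample_rate, audio))
print(f"Transcription: {text!r}")  # a pure tone usually yields empty or junk text
```

Because the model and processor are module-level globals, constructing a second `WhisperSTTModel` with a different size replaces the checkpoint for every caller; the trade-off is that only one copy is ever resident in memory.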