Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torchaudio | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| from transformers import AutoProcessor, AutoModelForCTC | |
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
| # load examples | |
| examples = [] | |
| examples_dir = "examples" | |
| if os.path.exists(examples_dir): | |
| for filename in os.listdir(examples_dir): | |
| if filename.endswith((".wav", ".mp3", ".ogg")): | |
| examples.append([os.path.join(examples_dir, filename)]) | |
# Fetch the Swahili CTC checkpoint and its matching processor from the Hub.
MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Only the model's weights are placed on `device`; the processor itself stays
# on the CPU and its output tensors are moved to `device` at inference time.
model = AutoModelForCTC.from_pretrained(MODEL_PATH).to(device)
# On ZeroGPU hardware a GPU is only attached while a function decorated with
# @spaces.GPU is running; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def process_audio(audio_path):
    """Transcribe an audio file with the Swahili CTC model.

    The audio is downmixed to mono, resampled to 16 kHz when needed, turned
    into model inputs by ``processor`` and decoded greedily (argmax over the
    CTC logits followed by ``processor.batch_decode``).

    Args:
        audio_path: Path to the audio file to be transcribed, as supplied by
            a ``gr.Audio(type="filepath")`` component. May be ``None`` or
            empty when the user submits without providing audio.

    Returns:
        String containing the transcribed text from the audio file, or an
        error message if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    target_sample_rate = 16000

    # torchaudio.load returns a (channels, frames) float tensor and its rate.
    waveform, sample_rate = torchaudio.load(audio_path)

    # Average the channels into a single 1-D signal so stereo recordings are
    # transcribed as one utterance, instead of passing a multi-row tensor and
    # depending on how the feature extractor interprets the extra rows.
    waveform = waveform.mean(dim=0)

    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )

    inputs = processor(
        waveform.numpy(), sampling_rate=target_sample_rate, return_tensors="pt"
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: take the most likely token id for every frame and let
    # the processor's tokenizer turn the id sequence into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    return transcriptions[0].strip()
# Define Gradio interface
with gr.Blocks(title="Swahili-ASR Demo") as demo:
    gr.Markdown("# Swahili-ASR 🎙️ Speech Recognition for Swahili Language 🥥")
    gr.Markdown(
        'Developed with <span style="color:red;">❤</span> by <a href="https://badrex.github.io/">Badr al-Absi</a>'
    )
    gr.Markdown(
        """### Hi there 👋🏼
This is a demo for [badrex/w2v-bert-2.0-swahili-asr](https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr),
a robust Transformer-based automatic speech recognition (ASR) system for Swahili language that was trained on 400+ hours of human-transcribed speech.
"""
    )
    gr.Markdown("Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!")

    # Left column: audio input and trigger button; right column: transcript.
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text,
    )

    # Only build the examples gallery when example files were found, rather
    # than handing gr.Examples `None` (its documented `examples` types are a
    # list or a directory path).
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )
# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    # Enable Gradio's request queue, then launch.
    # Other launch options considered: share=False, ssr_mode=False, mcp_server=True
    queued_app = demo.queue()
    queued_app.launch()
| # demo = gr.Interface( | |
| # fn=transcribe, | |
| # inputs=gr.Audio(), | |
| # outputs="text", | |
| # title="<div></div>", | |
| # description=""" | |
| # <div class="centered-content"> | |
| # <div> | |
| # <p> | |
| # Developed with β€ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> β | |
| # </p> | |
| # <br> | |
| # <p style="font-size: 15px; line-height: 1.8;"> | |
| # Hi there ππΌ | |
| # <br> | |
| # <br> | |
| # This is a demo for <a href="https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr" style="color: #2563eb;"> badrex/w2v-bert-2.0-swahili-asr</a>, a robust Transformer-based automatic speech recognition (ASR) system for Swahili language. | |
| # The underlying ASR model was trained on more than 400 hours of transcribed speech. | |
| # <br> | |
| # <p style="font-size: 15px; line-height: 1.8;"> | |
| # Simply <strong>upload an audio file</strong> π€ or <strong>record yourself speaking</strong> ποΈβΊοΈ to try out the model! | |
| # </p> | |
| # </div> | |
| # </div> | |
| # """, | |
| # examples=examples if examples else None, | |
| # cache_examples=False, | |
| # flagging_mode=None, | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() |