Spaces:
badrex
/
Running on Zero

Swahili-ASR / app.py
badrex's picture
Update app.py
d982d70 verified
import os
import torchaudio
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForCTC
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# File extensions (lower-case) that are offered as clickable audio examples.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".ogg")


def list_audio_examples(directory, extensions=AUDIO_EXTENSIONS):
    """Collect audio files in *directory* in the shape ``gr.Examples`` expects.

    Args:
        directory: Folder to scan (not recursive).
        extensions: Iterable of file suffixes to accept; matching is
            case-insensitive, so ``clip.WAV`` matches ``".wav"``.

    Returns:
        A list of one-element lists ``[[path], ...]`` (one row per example,
        one column for the single audio input), sorted by file name.
        Returns an empty list when *directory* does not exist or is not a
        directory.
    """
    if not os.path.isdir(directory):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    # os.listdir order is arbitrary; sort so the examples appear in a
    # stable order across restarts.
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if name.lower().endswith(suffixes)
    ]


# load examples
examples_dir = "examples"
examples = list_audio_examples(examples_dir)
# Hugging Face Hub checkpoint backing this demo; both the processor and the
# CTC model are loaded from it.
MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# Only the model is placed on `device`; the processor is used as loaded.
model = AutoModelForCTC.from_pretrained(MODEL_PATH).to(device)
# Sample rate the processor is called with; input audio is resampled to it.
TARGET_SAMPLE_RATE = 16000


@spaces.GPU()
def process_audio(audio_path):
    """Transcribe an audio file with the Swahili CTC model.

    Args:
        audio_path: Path to the audio file to be transcribed, as delivered by
            ``gr.Audio(type="filepath")``. May be empty/``None`` when the
            user submits without providing audio.

    Returns:
        The transcription with surrounding whitespace stripped, or the
        message ``"Please upload an audio file."`` when ``audio_path`` is
        empty.
    """
    if not audio_path:
        return "Please upload an audio file."

    # torchaudio.load returns a (channels, frames) tensor plus its sample rate.
    waveform, sample_rate = torchaudio.load(audio_path)
    # Downmix to mono so every channel contributes and the processor always
    # receives a 1-D signal (a mono file is unchanged by this).
    waveform = waveform.mean(dim=0)

    if sample_rate != TARGET_SAMPLE_RATE:
        # Functional form avoids building a new Resample module on every call.
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
        )

    inputs = processor(waveform, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Greedy decoding: most likely token id along the last (vocabulary) axis.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcriptions = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    # A single file was processed, so the batch has exactly one entry.
    return transcriptions[0].strip()
# Define Gradio interface
with gr.Blocks(title="Swahili-ASR Demo") as demo:
    gr.Markdown("# Swahili-ASR 🎙️ Speech Recognition for Swahili Language 🥥")
    gr.Markdown(
        'Developed with <span style="color:red;">❀</span> by <a href="https://badrex.github.io/">Badr al-Absi</a>'
    )
    gr.Markdown(
        """### Hi there 👋🏼
This is a demo for [badrex/w2v-bert-2.0-swahili-asr](https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr),
a robust Transformer-based automatic speech recognition (ASR) system for Swahili language that was trained on 400+ hours of human-transcribed speech.
"""
    )
    gr.Markdown("Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!")

    with gr.Row():
        with gr.Column():
            # type="filepath" hands process_audio a path on disk.
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)

    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text,
    )

    # gr.Examples needs an actual list of examples. Passing None (the
    # `examples if examples else None` idiom from gr.Interface) fails at
    # startup, so only render the section when example files were found.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )
# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    # queue() returns the same Blocks object with request queuing enabled.
    # Launch options considered previously: share=False, ssr_mode=False,
    # mcp_server=True.
    queued_demo = demo.queue()
    queued_demo.launch()
# demo = gr.Interface(
# fn=transcribe,
# inputs=gr.Audio(),
# outputs="text",
# title="<div></div>",
# description="""
# <div class="centered-content">
# <div>
# <p>
# Developed with ❀ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> ☕
# </p>
# <br>
# <p style="font-size: 15px; line-height: 1.8;">
# Hi there 👋🏼
# <br>
# <br>
# This is a demo for <a href="https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr" style="color: #2563eb;"> badrex/w2v-bert-2.0-swahili-asr</a>, a robust Transformer-based automatic speech recognition (ASR) system for Swahili language.
# The underlying ASR model was trained on more than 400 hours of transcribed speech.
# <br>
# <p style="font-size: 15px; line-height: 1.8;">
# Simply <strong>upload an audio file</strong> 📤 or <strong>record yourself speaking</strong> 🎙️⏺️ to try out the model!
# </p>
# </div>
# </div>
# """,
# examples=examples if examples else None,
# cache_examples=False,
# flagging_mode=None,
# )
# if __name__ == "__main__":
# demo.launch()