import os
import torchaudio
import gradio as gr
import spaces
import torch
from transformers import AutoProcessor, AutoModelForCTC

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# load examples 
examples = []
examples_dir = "examples"
if os.path.exists(examples_dir):
    for filename in os.listdir(examples_dir):
        if filename.endswith((".wav", ".mp3", ".ogg")):
            examples.append([os.path.join(examples_dir, filename)])

# Load model and processor
MODEL_PATH = "badrex/w2v-bert-2.0-swahili-asr"
processor = AutoProcessor.from_pretrained(MODEL_PATH)
model = AutoModelForCTC.from_pretrained(MODEL_PATH)

# move the model to the target device (the processor holds no tensors to move)
model = model.to(device)

@spaces.GPU()
def process_audio(audio_path):
    """Process audio with return the generated respotextnse.
    
    Args:
        audio_path: Path to the audio file to be transcribed.    
    Returns:
        String containing the transcribed text from the audio file, or an error message
        if the audio file is missing.
    """
    if not audio_path:
        return "Please upload an audio file."

    # load audio as a (channels, samples) tensor
    audio_array, sample_rate = torchaudio.load(audio_path)

    # downmix multi-channel recordings to mono
    if audio_array.shape[0] > 1:
        audio_array = audio_array.mean(dim=0, keepdim=True)

    # if sample rate is not 16000, resample to 16000
    if sample_rate != 16000:
        audio_array = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio_array)

    #audio_array = audio_array.to(device)

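    # convert the raw waveform into the model's input features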
    inputs = processor(audio_array, sampling_rate=16000, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    #inputs = inputs.to(device, dtype=torch.bfloat16)

    with torch.no_grad():
        logits = model(**inputs).logits

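    # greedy CTC decoding: pick the highest-scoring token at each frame;
    # batch_decode then collapses repeats and drops blank/special tokens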
    outputs = torch.argmax(logits, dim=-1)
    
    decoded_outputs = processor.batch_decode(
        outputs,
        skip_special_tokens=True
    )
    
    return decoded_outputs[0].strip()


# Define Gradio interface
with gr.Blocks(title="Swahili ASR Demo") as demo:
    gr.Markdown("# Swahili-ASR 🎙️ Speech Recognition for Swahili Language 🥥")
    #gr.Markdown("Developed with ❤ by [Badr al-Absi](https://badrex.github.io/)")
    gr.Markdown(    
        'Developed with <span style="color:red;">❤</span> by <a href="https://badrex.github.io/">Badr al-Absi</a>'
    )
    gr.Markdown(
        """### Hi there 👋🏼

This is a demo for [badrex/w2v-bert-2.0-swahili-asr](https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr), 
a robust Transformer-based automatic speech recognition (ASR) system for the Swahili language, trained on 400+ hours of human-transcribed speech.
    """
    )

    gr.Markdown("Simply **upload an audio file** 📤 or **record yourself speaking** 🎙️⏺️ to try out the model!")
    
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(type="filepath", label="Upload Audio")
            submit_btn = gr.Button("Transcribe Audio", variant="primary")
        
        with gr.Column():
            output_text = gr.Textbox(label="Text Transcription", lines=10)
    
    submit_btn.click(
        fn=process_audio,
        inputs=[audio_input],
        outputs=output_text
    )

    # only render the examples widget when example clips were found
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[audio_input],
        )

# Launch the app
if __name__ == "__main__":
    demo.queue().launch() #share=False, ssr_mode=False, mcp_server=True

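# A minimal sketch for smoke-testing the model outside the Gradio UI, using
# whichever example clip was discovered above (if any):
#
#     if examples:
#         print(process_audio(examples[0][0]))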
    
# demo = gr.Interface(
#     fn=transcribe,
#     inputs=gr.Audio(),
#     outputs="text",
#     title="<div></div>",
#     description="""
#         <div class="centered-content">
#             <div>
#                 <p>
#                 Developed with ❤ by <a href="https://badrex.github.io/" style="color: #2563eb;">Badr al-Absi</a> ☕
#                 </p>
#                 <br>
#                 <p style="font-size: 15px; line-height: 1.8;">
#                  Hi there 👋🏼
#                 <br>
#                 <br>
#                  This is a demo for <a href="https://huggingface.co/badrex/w2v-bert-2.0-swahili-asr" style="color: #2563eb;"> badrex/w2v-bert-2.0-swahili-asr</a>, a robust Transformer-based automatic speech recognition (ASR) system for Swahili language.
#                  The underlying ASR model was trained on more than 400 hours of transcribed speech.
#                 <br>                   
#                 <p style="font-size: 15px; line-height: 1.8;">
#                 Simply <strong>upload an audio file</strong> 📤 or <strong>record yourself speaking</strong> 🎙️⏺️ to try out the model!
#                 </p>
#             </div>
#         </div>
#         """,
#     examples=examples if examples else None,
#     cache_examples=False,  
#     flagging_mode=None,
# )

# if __name__ == "__main__":
#     demo.launch()