V10
app.py
CHANGED
@@ -1,76 +1,95 @@
Before:

  import os
  import time
- import
- import
  from datasets import load_dataset
- from transformers import pipeline

- # Get
  hf_token = os.getenv("HF_TOKEN")

- # Load Hindi
- test_ds = load_dataset(

  # Metrics
  wer_metric = evaluate.load("wer")
  cer_metric = evaluate.load("cer")

- # Models to
- "IndicConformer (AI4Bharat)":
  }

- tokenizer=model_id,
- feature_extractor=model_id,
- use_auth_token=hf_token if "jiviai" in model_id else None
- )

- elapsed = time.time() - start_time

- })

- print("\n===== Final Comparison =====")
- print(df.to_string(index=False))
After:

  import os
  import time
+ import torch
+ import gradio as gr
  from datasets import load_dataset
+ from transformers import pipeline
+ import evaluate

+ # Get token for gated repos
  hf_token = os.getenv("HF_TOKEN")

+ # Load Hindi audio samples (Common Voice Hindi test subset)
+ test_ds = load_dataset(
+     "mozilla-foundation/common_voice_11_0",
+     "hi",
+     split="test[:3]",
+     use_auth_token=hf_token,
+ )

+ # Prepare references and audio arrays
+ refs = [sample["sentence"] for sample in test_ds]
+ audio_samples = [sample["audio"]["array"] for sample in test_ds]
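Common Voice 11.0 clips ship at 48 kHz, while the bare arrays extracted above are later fed to pipelines that assume 16 kHz input (and the RTF divisor further down is 16000). A minimal sketch of resampling the column before the arrays are pulled out, assuming the datasets.Audio cast is available in the installed datasets version:

from datasets import Audio

# Resample the audio column to 16 kHz so the raw arrays match what the
# ASR feature extractors expect and what the RTF divisor assumes.
test_ds = test_ds.cast_column("audio", Audio(sampling_rate=16000))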

  # Metrics
  wer_metric = evaluate.load("wer")
  cer_metric = evaluate.load("cer")

+ # Models to test
+ MODELS = {
+     "IndicConformer (AI4Bharat)": {
+         "model_id": "ai4bharat/indic-conformer-600m-multilingual",
+         "trust_remote_code": True,
+         "auth": None
+     },
+     "AudioX-North (Jivi AI)": {
+         "model_id": "jiviai/audioX-north-v1",
+         "trust_remote_code": False,
+         "auth": hf_token
+     },
+     "MMS (Facebook)": {
+         "model_id": "facebook/mms-1b-all",
+         "trust_remote_code": False,
+         "auth": None
+     }
  }
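Other checkpoints can be benchmarked by appending entries of the same shape; a sketch using openai/whisper-small as an illustrative extra entry (not listed in the MODELS table above):

# Hypothetical extra entry; any ASR checkpoint on the Hub works the same way.
MODELS["Whisper small (OpenAI)"] = {
    "model_id": "openai/whisper-small",  # multilingual, covers Hindi
    "trust_remote_code": False,
    "auth": None
}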

+ def eval_model(model_info):
+     args = {
+         "model": model_info["model_id"],
+         "device": -1  # CPU only
+     }
+     if model_info["trust_remote_code"]:
+         args["trust_remote_code"] = True
+     if model_info["auth"]:
+         args["use_auth_token"] = model_info["auth"]

+     asr = pipeline("automatic-speech-recognition", **args)
+     preds = []
+     start = time.time()
+     for audio in audio_samples:
+         out = asr(audio)
+         preds.append(out["text"].strip())
+     elapsed = time.time() - start

+     total_len = sum(len(a) for a in audio_samples) / 16000
+     rtf = elapsed / total_len

+     return {
+         "WER": wer_metric.compute(predictions=preds, references=refs),
+         "CER": cer_metric.compute(predictions=preds, references=refs),
+         "RTF": rtf
+     }
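eval_model hands the pipeline bare NumPy arrays, which the transformers ASR pipeline treats as 16 kHz audio. If the dataset is not resampled first, the sampling rate can travel with each clip instead; a sketch, reusing test_ds from above and the dict input form the pipeline accepts:

# Keep the sampling rate next to each array so the pipeline can resample
# internally instead of assuming 16 kHz input.
audio_inputs = [
    {"raw": s["audio"]["array"], "sampling_rate": s["audio"]["sampling_rate"]}
    for s in test_ds
]
# Each element can then be passed to the pipeline, e.g. asr(audio_inputs[0]),
# and still returns a {"text": ...} dict.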

+ def run_all():
+     rows = []
+     for name, cfg in MODELS.items():
+         try:
+             res = eval_model(cfg)
+             rows.append([name, f"{res['WER']:.3f}", f"{res['CER']:.3f}", f"{res['RTF']:.2f}"])
+         except Exception as e:
+             rows.append([name, "Error", "Error", "Error"])
+     return rows
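run_all catches exceptions but drops e, so a failing model shows only "Error" in every column. A possible variant that surfaces a short message instead, with the rest of the loop unchanged:

def run_all():
    rows = []
    for name, cfg in MODELS.items():
        try:
            res = eval_model(cfg)
            rows.append([name, f"{res['WER']:.3f}", f"{res['CER']:.3f}", f"{res['RTF']:.2f}"])
        except Exception as e:
            # Keep the message short so the Dataframe cell stays readable.
            rows.append([name, f"Error: {str(e)[:80]}", "-", "-"])
    return rows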

+ with gr.Blocks() as demo:
+     gr.Markdown("### ASR Model Benchmark (Hindi Samples)\nWER, CER, and RTF comparison.")
+     btn = gr.Button("Run Benchmark")
+     table = gr.Dataframe(
+         headers=["Model", "WER", "CER", "RTF"],
+         datatype=["str", "str", "str", "str"],
+         interactive=False
+     )
+     btn.click(run_all, outputs=table)

+ if __name__ == "__main__":
+     demo.launch()