AvtnshM committed on
Commit 1aa3ed3 · verified · 1 Parent(s): d9e9840
Files changed (1)
  1. app.py +76 -57
app.py CHANGED
@@ -1,76 +1,95 @@
import os
import time
- import evaluate
- import pandas as pd
from datasets import load_dataset
- from transformers import pipeline, AutoProcessor, AutoModelForCTC

- # Get HF token from secret (for gated repos like Jivi)
hf_token = os.getenv("HF_TOKEN")

- # Load Hindi dataset (tiny sample for speed)
- test_ds = load_dataset("mozilla-foundation/common_voice_11_0_hi", split="test[:3]")

# Metrics
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

- # Models to compare
- models = {
-     "IndicConformer (AI4Bharat)": "ai4bharat/IndicConformer-hi",
-     "AudioX-North (Jivi AI)": "jiviai/audioX-north-v1",
-     "MMS (Facebook)": "facebook/mms-1b-all"
}

- results = []
-
- for model_name, model_id in models.items():
-     print(f"\n🔹 Running {model_name} ...")
-     try:
-         # Init pipeline
-         asr = pipeline(
-             "automatic-speech-recognition",
-             model=model_id,
-             tokenizer=model_id,
-             feature_extractor=model_id,
-             use_auth_token=hf_token if "jiviai" in model_id else None
-         )

-         # Test loop
-         for sample in test_ds:
-             audio = sample["audio"]["array"]
-             ref_text = sample["sentence"]

-             start_time = time.time()
-             pred_text = asr(audio)["text"]
-             elapsed = time.time() - start_time

-             # Metrics
-             wer = wer_metric.compute(predictions=[pred_text], references=[ref_text])
-             cer = cer_metric.compute(predictions=[pred_text], references=[ref_text])
-             rtf = elapsed / (len(audio) / 16000)  # real-time factor (audio length at 16kHz)

-             results.append({
-                 "Model": model_name,
-                 "Reference": ref_text,
-                 "Prediction": pred_text,
-                 "WER": round(wer, 3),
-                 "CER": round(cer, 3),
-                 "RTF": round(rtf, 3)
-             })

-     except Exception as e:
-         results.append({
-             "Model": model_name,
-             "Reference": "-",
-             "Prediction": "-",
-             "WER": None,
-             "CER": None,
-             "RTF": None,
-             "Error": str(e)
-         })

- # Convert results to DataFrame
- df = pd.DataFrame(results)
- print("\n===== Final Comparison =====")
- print(df.to_string(index=False))
import os
import time
+ import torch
+ import gradio as gr
from datasets import load_dataset
+ from transformers import pipeline
+ import evaluate

+ # Get token for gated repos
hf_token = os.getenv("HF_TOKEN")

+ # Load Hindi audio samples (Common Voice Hindi test subset)
+ test_ds = load_dataset(
+     "mozilla-foundation/common_voice_11_0",
+     "hi",
+     split="test[:3]",
+     use_auth_token=hf_token,
+ )
+
+ # Prepare references and audio arrays
+ refs = [sample["sentence"] for sample in test_ds]
+ audio_samples = [sample["audio"]["array"] for sample in test_ds]

# Metrics
wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

+ # Models to test
+ MODELS = {
+     "IndicConformer (AI4Bharat)": {
+         "model_id": "ai4bharat/indic-conformer-600m-multilingual",
+         "trust_remote_code": True,
+         "auth": None
+     },
+     "AudioX-North (Jivi AI)": {
+         "model_id": "jiviai/audioX-north-v1",
+         "trust_remote_code": False,
+         "auth": hf_token
+     },
+     "MMS (Facebook)": {
+         "model_id": "facebook/mms-1b-all",
+         "trust_remote_code": False,
+         "auth": None
+     }
}

+ def eval_model(model_info):
+     # Build pipeline kwargs; device=-1 keeps inference on CPU only
+     args = {
+         "model": model_info["model_id"],
+         "device": -1
+     }
+     if model_info["trust_remote_code"]:
+         args["trust_remote_code"] = True
+     if model_info["auth"]:
+         args["use_auth_token"] = model_info["auth"]

+     asr = pipeline("automatic-speech-recognition", **args)
+     preds = []
+     start = time.time()
+     for audio in audio_samples:
+         out = asr(audio)
+         preds.append(out["text"].strip())
+     elapsed = time.time() - start

+     # Total audio duration in seconds (arrays sampled at 16 kHz)
+     total_len = sum(len(a) for a in audio_samples) / 16000
+     rtf = elapsed / total_len

+     return {
+         "WER": wer_metric.compute(predictions=preds, references=refs),
+         "CER": cer_metric.compute(predictions=preds, references=refs),
+         "RTF": rtf
+     }

+ def run_all():
+     rows = []
+     for name, cfg in MODELS.items():
+         try:
+             res = eval_model(cfg)
+             rows.append([name, f"{res['WER']:.3f}", f"{res['CER']:.3f}", f"{res['RTF']:.2f}"])
+         except Exception as e:
+             rows.append([name, "Error", "Error", "Error"])
+     return rows

+ with gr.Blocks() as demo:
+     gr.Markdown("### ASR Model Benchmark (Hindi Samples)\nWER, CER, and RTF comparison.")
+     btn = gr.Button("Run Benchmark")
+     table = gr.Dataframe(
+         headers=["Model", "WER", "CER", "RTF"],
+         datatype=["str", "str", "str", "str"],
+         interactive=False
+     )
+     btn.click(run_all, outputs=table)

+ if __name__ == "__main__":
+     demo.launch()
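
For reference, a minimal sketch of how the metrics reported by this app behave, using the same `evaluate` API as app.py. The reference/prediction strings and the timing numbers below are hypothetical, chosen only to illustrate the arithmetic; they are not drawn from the dataset or the models above.

import evaluate

wer_metric = evaluate.load("wer")
cer_metric = evaluate.load("cer")

# Hypothetical pair, for illustration only.
refs = ["the cat sat on the mat"]
preds = ["the cat sat on a mat"]

# WER = (substitutions + insertions + deletions) / reference word count;
# one substitution ("the" -> "a") over 6 words gives 1/6 ~= 0.167.
print(wer_metric.compute(predictions=preds, references=refs))

# CER applies the same edit distance at the character level.
print(cer_metric.compute(predictions=preds, references=refs))

# RTF = processing time / audio duration: transcribing 10 s of audio
# in 2.5 s gives RTF = 0.25, i.e. 4x faster than real time.
print(2.5 / 10.0)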