AvtnshM committed on
Commit d9e9840 · verified · 1 Parent(s): 4924252
Files changed (1):
  1. app.py +58 -69
app.py CHANGED
@@ -1,87 +1,76 @@
-import time
 import os
+import time
 import evaluate
-import gradio as gr
+import pandas as pd
 from datasets import load_dataset
-from transformers import pipeline
+from transformers import pipeline, AutoProcessor, AutoModelForCTC
+
+# Get HF token from secret (for gated repos like Jivi)
+hf_token = os.getenv("HF_TOKEN")
+
+# Load Hindi dataset (tiny sample for speed)
+test_ds = load_dataset("mozilla-foundation/common_voice_11_0_hi", split="test[:3]")
 
-# -----------------
-# Load evaluation metrics
+# Metrics
 wer_metric = evaluate.load("wer")
 cer_metric = evaluate.load("cer")
 
-# -----------------
-# Small sample dataset for Hindi
-test_ds = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test[:3]")
-
-# Extract references + audio
-refs = [x["sentence"] for x in test_ds]
-audio_data = [x["audio"]["array"] for x in test_ds]
-
-# -----------------
-# Helper to evaluate model
-def evaluate_model(model_name, pipeline_kwargs=None):
+# Models to compare
+models = {
+    "IndicConformer (AI4Bharat)": "ai4bharat/IndicConformer-hi",
+    "AudioX-North (Jivi AI)": "jiviai/audioX-north-v1",
+    "MMS (Facebook)": "facebook/mms-1b-all"
+}
+
+results = []
+
+for model_name, model_id in models.items():
+    print(f"\n🔹 Running {model_name} ...")
     try:
-        start = time.time()
-        asr_pipeline = pipeline(
+        # Init pipeline
+        asr = pipeline(
             "automatic-speech-recognition",
-            model=model_name,
-            device=-1,  # CPU only
-            **(pipeline_kwargs or {})
+            model=model_id,
+            tokenizer=model_id,
+            feature_extractor=model_id,
+            use_auth_token=hf_token if "jiviai" in model_id else None
         )
 
-        preds = []
-        for audio in audio_data:
-            out = asr_pipeline(audio, chunk_length_s=30, return_timestamps=False)
-            preds.append(out["text"])
+        # Test loop
+        for sample in test_ds:
+            audio = sample["audio"]["array"]
+            ref_text = sample["sentence"]
 
-        end = time.time()
-        rtf = (end - start) / sum(len(a) / 16000 for a in audio_data)
+            start_time = time.time()
+            pred_text = asr(audio)["text"]
+            elapsed = time.time() - start_time
 
-        return {
-            "WER": wer_metric.compute(predictions=preds, references=refs),
-            "CER": cer_metric.compute(predictions=preds, references=refs),
-            "RTF": rtf
-        }
+            # Metrics
+            wer = wer_metric.compute(predictions=[pred_text], references=[ref_text])
+            cer = cer_metric.compute(predictions=[pred_text], references=[ref_text])
+            rtf = elapsed / (len(audio) / 16000)  # real-time factor (audio length at 16kHz)
 
+            results.append({
+                "Model": model_name,
+                "Reference": ref_text,
+                "Prediction": pred_text,
+                "WER": round(wer, 3),
+                "CER": round(cer, 3),
+                "RTF": round(rtf, 3)
+            })
 
-    except Exception as e:
-        return {"Error": str(e)}
+    except Exception as e:
+        results.append({
+            "Model": model_name,
+            "Reference": "-",
+            "Prediction": "-",
+            "WER": None,
+            "CER": None,
+            "RTF": None,
+            "Error": str(e)
+        })
 
-# -----------------
-# Models to test
-models = {
-    "IndicConformer (AI4Bharat)": {
-        "name": "ai4bharat/IndicConformer-Hi",
-        "pipeline_kwargs": {"trust_remote_code": True}
-    },
-    "AudioX-North (Jivi AI)": {
-        "name": "jiviai/audioX-north-v1",
-        "pipeline_kwargs": {"use_auth_token": os.environ.get("HF_TOKEN")}
-    },
-    "MMS (Facebook)": {
-        "name": "facebook/mms-1b-all",
-        "pipeline_kwargs": {}
-    }
-}
-
-# -----------------
-# Gradio interface
-def run_evaluations():
-    rows = []
-    for label, cfg in models.items():
-        res = evaluate_model(cfg["name"], cfg["pipeline_kwargs"])
-        if "Error" in res:
-            rows.append([label, res["Error"], "-", "-"])
-        else:
-            rows.append([label, f"{res['WER']:.3f}", f"{res['CER']:.3f}", f"{res['RTF']:.2f}"])
-    return rows
-
-with gr.Blocks() as demo:
-    gr.Markdown("## ASR Benchmark Comparison (Hindi Sample)\nEvaluating **WER, CER, RTF** across models.")
-    btn = gr.Button("Run Evaluation")
-    table = gr.Dataframe(headers=["Model", "WER", "CER", "RTF"], datatype=["str", "str", "str", "str"], interactive=False)
-
-    btn.click(fn=run_evaluations, outputs=table)
-
-if __name__ == "__main__":
-    demo.launch()
+
+# Convert results to DataFrame
+df = pd.DataFrame(results)
+print("\n===== Final Comparison =====")
+print(df.to_string(index=False))
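
A note on the dataset change: the new script loads "mozilla-foundation/common_voice_11_0_hi", but on the Hub Common Voice 11 is a single gated repo whose languages are configs, so the old two-argument form is likely the one that actually resolves. Below is a sketch that keeps the config form, passes the token (the repo sits behind a license gate), and resamples to 16 kHz; token= assumes a recent datasets release where use_auth_token was renamed:

from datasets import load_dataset, Audio

# Hindi is the "hi" config of the gated Common Voice repo, not a separate repo.
test_ds = load_dataset(
    "mozilla-foundation/common_voice_11_0", "hi",
    split="test[:3]",
    token=hf_token,  # gated dataset: needs an accepted license + access token
)

# Common Voice ships 48 kHz clips; resampling makes len(audio) / 16000 the
# actual clip duration that the RTF computation divides by.
test_ds = test_ds.cast_column("audio", Audio(sampling_rate=16_000))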
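On the pipeline setup: recent transformers releases renamed use_auth_token to token (the old name is deprecated), and passing tokenizer= and feature_extractor= is normally redundant when they live in the same repo as the model. Feeding the pipeline a bare numpy array also makes it assume the audio is already at the model's expected sampling rate; a dict input carries the rate so the pipeline can resample. A minimal sketch under those assumptions:

asr = pipeline(
    "automatic-speech-recognition",
    model=model_id,
    device=-1,  # CPU, as in the previous version
    token=hf_token if "jiviai" in model_id else None,
)

# Dict input tells the pipeline the true rate; a bare array is trusted as-is.
pred_text = asr({
    "raw": sample["audio"]["array"],
    "sampling_rate": sample["audio"]["sampling_rate"],
})["text"]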
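For reading the RTF column: real-time factor is wall-clock processing time divided by audio duration, so values below 1.0 mean faster than real time. For example, spending 2.0 s transcribing an 8.0 s clip gives RTF = 2.0 / 8.0 = 0.25, i.e. 4x faster than real time:

duration_s = len(audio) / 16000  # e.g. 128000 samples at 16 kHz -> 8.0 s
rtf = elapsed / duration_s       # e.g. 2.0 s / 8.0 s -> 0.25

The denominator is only the true duration when the array really is at 16 kHz, which is why the resampling sketch above matters.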
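Finally, the old version reported one corpus-level WER/CER per model, while the new loop scores each clip separately. Corpus WER weights errors by total reference length, so it generally differs from the mean of the per-clip rows in the table. If a single score per model is still wanted, a hypothetical helper like this over the collected rows would recover it:

def corpus_scores(results, model_name):
    # Only successful rows for this model; error rows carry no text.
    rows = [r for r in results if r["Model"] == model_name and "Error" not in r]
    preds = [r["Prediction"] for r in rows]
    refs = [r["Reference"] for r in rows]
    return {
        "WER": wer_metric.compute(predictions=preds, references=refs),
        "CER": cer_metric.compute(predictions=preds, references=refs),
    }

# e.g. corpus_scores(results, "MMS (Facebook)")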