AvtnshM committed on
Commit 64a9e0f · verified · 1 Parent(s): 156d331
Files changed (1)
  1. app.py +68 -70
app.py CHANGED
@@ -1,82 +1,80 @@
 import time
-import torch
-import gradio as gr
-import torchaudio
-from transformers import (
-    WhisperProcessor, WhisperForConditionalGeneration,
-    AutoProcessor, AutoModelForCTC, pipeline
-)
-from jiwer import wer, cer

-# Utility to load audio and resample to 16 kHz
-def load_audio(fp):
-    waveform, sr = torchaudio.load(fp)
-    if sr != 16000:
-        waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
-    return waveform.squeeze(0), 16000

-# Evaluation function
-def eval_model(name, cfg, file, ref):
-    waveform, sr = load_audio(file)
-    start = time.time()

-    if cfg["type"] == "whisper":
-        proc = WhisperProcessor.from_pretrained(cfg["id"])
-        model = WhisperForConditionalGeneration.from_pretrained(cfg["id"])
-        pipe = pipeline(
-            "automatic-speech-recognition",
-            model=model,
-            tokenizer=proc.tokenizer,
-            feature_extractor=proc.feature_extractor,
-            device=-1
-        )
-    else:
-        proc = AutoProcessor.from_pretrained(cfg["id"], trust_remote_code=True)
-        model = AutoModelForCTC.from_pretrained(cfg["id"], trust_remote_code=True)
-        pipe = pipeline(
             "automatic-speech-recognition",
-            model=model,
-            tokenizer=proc.tokenizer,
-            feature_extractor=proc.feature_extractor,
-            device=-1
         )

-    result = pipe(waveform)
-    hyp = result["text"].lower()
-    w = wer(ref.lower() if ref else "", hyp) if ref else None
-    c = cer(ref.lower() if ref else "", hyp) if ref else None
-    rtf = (time.time() - start) / (waveform.shape[0] / sr)

-    return {"Transcription": hyp, "WER": w, "CER": c, "RTF": rtf}

-# Model configs
-MODELS = {
-    "IndicConformer (AI4Bharat)": {"id": "ai4bharat/indic-conformer-600m-multilingual", "type": "conformer"},
-    "AudioX-North (Jivi AI)": {"id": "jiviai/audioX-north-v1", "type": "whisper"},
-    "MMS (Facebook)": {"id": "facebook/mms-1b-all", "type": "conformer"},
-}

-# Gradio interface logic
-def compare_all(audio, reference, language):
-    results = {}
-    for name, cfg in MODELS.items():
-        try:
-            results[name] = eval_model(name, cfg, audio, reference)
-        except Exception as e:
-            results[name] = {"Error": str(e)}
-    return results

-demo = gr.Interface(
-    fn=compare_all,
-    inputs=[
-        gr.Audio(type="filepath", label="Upload Audio (<=20s recommended)"),
-        gr.Textbox(label="Reference Transcript (optional)"),
-        gr.Dropdown(choices=["hi","gu","ta"], label="Language", value="hi")
-    ],
-    outputs=gr.JSON(label="Benchmark Results"),
-    title="Indic ASR Benchmark (CPU-only)",
-    description="Compare IndicConformer, AudioX-North, and MMS on WER, CER, and RTF."
-)

-if __name__ == "__main__":
-    demo.launch()
 import time
+import os
+import evaluate
+from datasets import load_dataset
+from huggingface_hub import login
+from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor

+# 🔑 Authenticate using HF_TOKEN secret
+login(token=os.environ.get("HF_TOKEN"))

+# -----------------
+# Load evaluation metrics
+wer_metric = evaluate.load("wer")
+cer_metric = evaluate.load("cer")

+# -----------------
+# Small sample dataset for Hindi
+# (free Spaces can't handle large test sets)
+test_ds = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test[:3]")
+
+# Extract references + audio
+refs = [x["sentence"] for x in test_ds]
+audio_data = [x["audio"]["array"] for x in test_ds]
+
+results = {}
+
+# -----------------
+# Helper to evaluate model
+def evaluate_model(model_name, pipeline_kwargs=None):
+    try:
+        start = time.time()
+        asr_pipeline = pipeline(
             "automatic-speech-recognition",
+            model=model_name,
+            device=-1,  # CPU only
+            **(pipeline_kwargs or {})
         )

+        preds = []
+        for audio in audio_data:
+            out = asr_pipeline(audio, chunk_length_s=30, return_timestamps=False)
+            preds.append(out["text"])

+        end = time.time()
+        rtf = (end - start) / sum(len(a) / 16000 for a in audio_data)

+        return {
+            "Transcriptions": preds,
+            "WER": wer_metric.compute(predictions=preds, references=refs),
+            "CER": cer_metric.compute(predictions=preds, references=refs),
+            "RTF": rtf
+        }

+    except Exception as e:
+        return {"Error": str(e)}
+
+# -----------------
+# Models to test
+models = {
+    "IndicConformer (AI4Bharat)": {
+        "name": "ai4bharat/IndicConformer-Hi",
+        "pipeline_kwargs": {"trust_remote_code": True}
+    },
+    "AudioX-North (Jivi AI)": {
+        "name": "jiviai/audioX-north-v1",
+        "pipeline_kwargs": {"use_auth_token": os.environ.get("HF_TOKEN")}
+    },
+    "MMS (Facebook)": {
+        "name": "facebook/mms-1b-all",
+        "pipeline_kwargs": {}
+    }
+}

+# -----------------
+# Run evaluations
+for label, cfg in models.items():
+    print(f"Running {label}...")
+    results[label] = evaluate_model(cfg["name"], cfg["pipeline_kwargs"])

+print(results)
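
One caveat with the new script: `evaluate_model` divides elapsed time by `len(a) / 16000`, which equals the clip duration only if every decoded array really is 16 kHz, while Common Voice clips on the Hub are typically stored at 48 kHz. The sketch below is not part of the commit; the variable names and the dict-style pipeline input are illustrative. It shows one way to make the sampling-rate assumption explicit by resampling the dataset with `datasets.Audio` and passing the rate through to the pipeline.

```python
# Minimal sketch (illustrative, not from the commit): pin the audio to 16 kHz
# so the RTF divisor matches real clip durations. Assumes the same gated
# Common Voice Hindi split and the HF_TOKEN login shown above.
from datasets import Audio, load_dataset
from transformers import pipeline

SAMPLE_RATE = 16000  # assumption: the benchmarked models expect 16 kHz input

ds = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test[:3]")
# Common Voice ships 48 kHz clips; cast_column resamples them on access.
ds = ds.cast_column("audio", Audio(sampling_rate=SAMPLE_RATE))

refs = [x["sentence"] for x in ds]
samples = [
    {"raw": x["audio"]["array"], "sampling_rate": x["audio"]["sampling_rate"]}
    for x in ds
]
# Real clip durations, independent of a hard-coded rate.
total_audio_seconds = sum(len(s["raw"]) / s["sampling_rate"] for s in samples)

# A dict with an explicit sampling_rate lets the ASR pipeline resample or
# validate the input instead of silently assuming the model's rate.
asr = pipeline("automatic-speech-recognition", model="facebook/mms-1b-all", device=-1)
preds = [asr(s)["text"] for s in samples]
```

Note also that recent transformers releases prefer `token=` over the deprecated `use_auth_token=`, so the AudioX-North entry's `pipeline_kwargs` may emit a deprecation warning on newer versions.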