AvtnshM committed on
Commit d0422d9 · verified · 1 Parent(s): 4e42095
Files changed (1)
  1. app.py +93 -624
app.py CHANGED
@@ -1,653 +1,122 @@
  import gradio as gr
  import torch
  import torchaudio
- import numpy as np
- from transformers import (
-     AutoProcessor,
-     AutoModelForSpeechSeq2Seq,
-     AutoModelForCTC,
-     Wav2Vec2Processor,
-     Wav2Vec2ForCTC
- )
  import librosa
  import time
- import os
- from typing import Dict, Tuple, Optional
- import jiwer
- import warnings
- warnings.filterwarnings("ignore")
 
  # Model configurations
- MODELS_CONFIG = {
-     "IndicConformer-600M": {
-         "repo_id": "ai4bharat/indic-conformer-600m-multilingual",
-         "type": "conformer",
-         "params": "600M",
-         "languages": "22 Indian languages (Hindi, Bengali, Gujarati, Marathi, Tamil, Telugu, Kannada, Malayalam, etc.)",
-         "architecture": "Multilingual Conformer-based Hybrid CTC + RNNT",
-         "license": "MIT",
-         "description": "AI4Bharat's comprehensive ASR model for all 22 official Indian languages"
-     },
-     "AudioX-North": {
-         "repo_id": "placeholder/audiox-north",  # Replace with actual repo when available
-         "type": "audiox",
-         "params": "Unknown",
-         "languages": "Hindi, Gujarati, Marathi",
-         "architecture": "Fine-tuned ASR with domain adaptation",
-         "license": "Unknown",
-         "description": "Jivi AI's specialized model for North Indian languages"
      },
-     "AudioX-South": {
-         "repo_id": "placeholder/audiox-south",  # Replace with actual repo when available
-         "type": "audiox",
-         "params": "Unknown",
-         "languages": "Tamil, Telugu, Kannada, Malayalam",
-         "architecture": "Fine-tuned ASR with domain adaptation",
-         "license": "Unknown",
-         "description": "Jivi AI's specialized model for South Indian languages"
      },
-     "Facebook-MMS": {
-         "repo_id": "facebook/mms-1b-all",
-         "type": "mms",
-         "params": "1B",
-         "languages": "1400+ languages worldwide",
-         "architecture": "Wav2Vec2 self-supervised pretraining",
-         "license": "CC-BY-NC 4.0",
-         "description": "Facebook's massive multilingual speech model"
      }
  }
 
- # Benchmark data from AudioX (Vistaar Benchmark)
- VISTAAR_BENCHMARK = {
-     "Hindi": {"AudioX": 12.14, "ElevenLabs": 13.64, "Sarvam": 14.28, "IndicWhisper": 13.59, "Azure": 20.03, "GPT-4": 18.65, "Google": 23.89, "Whisper-v3": 32.00},
-     "Gujarati": {"AudioX": 18.66, "ElevenLabs": 17.96, "Sarvam": 19.47, "IndicWhisper": 22.84, "Azure": 31.62, "GPT-4": 31.32, "Google": 36.48, "Whisper-v3": 53.75},
-     "Marathi": {"AudioX": 18.68, "ElevenLabs": 16.51, "Sarvam": 18.34, "IndicWhisper": 18.25, "Azure": 27.36, "GPT-4": 25.21, "Google": 26.48, "Whisper-v3": 78.28},
-     "Tamil": {"AudioX": 21.79, "ElevenLabs": 24.84, "Sarvam": 25.73, "IndicWhisper": 25.27, "Azure": 31.53, "GPT-4": 39.10, "Google": 33.62, "Whisper-v3": 52.44},
-     "Telugu": {"AudioX": 24.63, "ElevenLabs": 24.89, "Sarvam": 26.80, "IndicWhisper": 28.82, "Azure": 31.38, "GPT-4": 33.94, "Google": 42.42, "Whisper-v3": 179.58},
-     "Kannada": {"AudioX": 17.61, "ElevenLabs": 17.65, "Sarvam": 18.95, "IndicWhisper": 18.33, "Azure": 26.45, "GPT-4": 32.88, "Google": 31.48, "Whisper-v3": 67.02},
-     "Malayalam": {"AudioX": 26.92, "ElevenLabs": 28.88, "Sarvam": 32.64, "IndicWhisper": 32.34, "Azure": 41.84, "GPT-4": 46.11, "Google": 47.90, "Whisper-v3": 142.98}
- }
-
- class ASRModelManager:
-     def __init__(self):
-         self.loaded_models = {}
-         self.processors = {}
-         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     def load_model(self, model_name: str) -> Tuple[object, object]:
-         """Load model and processor with error handling"""
-         if model_name in self.loaded_models:
-             return self.loaded_models[model_name], self.processors[model_name]
-
-         try:
-             config = MODELS_CONFIG[model_name]
-             repo_id = config["repo_id"]
-             model_type = config["type"]
-
-             if model_type == "conformer":
-                 # Load IndicConformer model
-                 processor = AutoProcessor.from_pretrained(repo_id)
-                 model = AutoModelForSpeechSeq2Seq.from_pretrained(
-                     repo_id,
-                     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-                     device_map="auto" if torch.cuda.is_available() else None
-                 )
-
-             elif model_type == "mms":
-                 # Load Facebook MMS model
-                 processor = Wav2Vec2Processor.from_pretrained(repo_id)
-                 model = Wav2Vec2ForCTC.from_pretrained(repo_id)
-                 model = model.to(self.device)
-
-             elif model_type == "audiox":
-                 # Placeholder for AudioX models - replace with actual implementation
-                 # For now, using a fallback model for demonstration
-                 processor = AutoProcessor.from_pretrained("openai/whisper-small")
-                 model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small")
-                 model = model.to(self.device)
-
-             self.loaded_models[model_name] = model
-             self.processors[model_name] = processor
-
-             return model, processor
-
-         except Exception as e:
-             raise Exception(f"Failed to load {model_name}: {str(e)}")
-
- def preprocess_audio(audio_path: str, target_sr: int = 16000) -> Tuple[np.ndarray, int]:
-     """Preprocess audio file for ASR inference"""
      try:
-         # Load and resample audio
-         audio, sr = librosa.load(audio_path, sr=target_sr)
-
-         # Normalize audio to prevent clipping
-         if np.max(np.abs(audio)) > 0:
-             audio = audio / np.max(np.abs(audio)) * 0.95
-
-         return audio, sr
-
      except Exception as e:
-         raise Exception(f"Audio preprocessing failed: {str(e)}")
 
- def calculate_wer_cer(reference: str, hypothesis: str) -> Tuple[float, float]:
-     """Calculate Word Error Rate and Character Error Rate"""
      try:
-         # Calculate WER using jiwer
-         wer = jiwer.wer(reference, hypothesis) * 100
-
-         # Calculate CER
-         cer = jiwer.cer(reference, hypothesis) * 100
-
-         return wer, cer
-
-     except Exception:
-         return 0.0, 0.0
 
- def transcribe_audio(
-     audio_file: str,
-     model_name: str,
-     reference_text: str = "",
-     language: str = "auto"
- ) -> Tuple[str, str, float, float, float]:
-     """Perform ASR transcription and calculate metrics"""
-
-     if audio_file is None:
-         return "❌ Please upload an audio file", "", 0.0, 0.0, 0.0
-
      try:
-         # Start timing for RTF calculation
-         start_time = time.time()
-
-         # Preprocess audio
-         audio, sr = preprocess_audio(audio_file)
          audio_duration = len(audio) / sr
 
-         # Load model and processor
-         model, processor = model_manager.load_model(model_name)
-
-         # Perform transcription based on model type
-         config = MODELS_CONFIG[model_name]
-
-         if config["type"] == "conformer":
-             # IndicConformer inference
-             inputs = processor(
-                 audio,
-                 sampling_rate=sr,
-                 return_tensors="pt",
-                 padding=True
-             )
-
-             if torch.cuda.is_available():
-                 inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 predicted_ids = model.generate(**inputs, max_length=448)
-
-             transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
-
-         elif config["type"] == "mms":
-             # Facebook MMS inference
-             inputs = processor(
-                 audio,
-                 sampling_rate=sr,
-                 return_tensors="pt",
-                 padding=True
-             )
-
-             if torch.cuda.is_available():
-                 inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 logits = model(**inputs).logits
-
-             predicted_ids = torch.argmax(logits, dim=-1)
-             transcription = processor.decode(predicted_ids[0])
-
-         elif config["type"] == "audiox":
-             # AudioX placeholder implementation
-             inputs = processor(
-                 audio,
-                 sampling_rate=sr,
-                 return_tensors="pt"
-             )
-
-             if torch.cuda.is_available():
-                 inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-             with torch.no_grad():
-                 predicted_ids = model.generate(**inputs, max_length=448)
-
-             transcription = processor.decode(predicted_ids[0], skip_special_tokens=True)
-
-         # Calculate processing time and RTF
-         end_time = time.time()
-         processing_time = end_time - start_time
-         rtf = processing_time / audio_duration
-
-         # Calculate WER and CER if reference provided
-         wer, cer = 0.0, 0.0
-         if reference_text.strip():
-             wer, cer = calculate_wer_cer(reference_text.strip(), transcription.strip())
-
-         # Format model info
-         model_info = f"""
- 🤖 Model: {model_name}
- 📊 Parameters: {config['params']}
- 🗣️ Languages: {config['languages']}
- ⚙️ Architecture: {config['architecture']}
- ⏱️ Processing Time: {processing_time:.2f}s
- 🎵 Audio Duration: {audio_duration:.2f}s
- """
-
-         return transcription.strip(), model_info, wer, cer, rtf
 
      except Exception as e:
-         return f"❌ Error: {str(e)}", "", 0.0, 0.0, 0.0
-
- def create_benchmark_table():
-     """Create the Vistaar benchmark comparison table"""
-     # Headers
-     headers = ["Language", "AudioX", "ElevenLabs", "Sarvam", "IndicWhisper", "Azure STT", "GPT-4", "Google STT", "Whisper-v3"]
-
-     # Data rows
-     rows = []
-     for lang, scores in VISTAAR_BENCHMARK.items():
-         row = [lang] + [f"{score:.2f}%" for score in scores.values()]
-         rows.append(row)
-
-     # Calculate and add average row
-     avg_row = ["🏆 Average"]
-     for provider in VISTAAR_BENCHMARK["Hindi"].keys():
-         avg_score = np.mean([VISTAAR_BENCHMARK[lang][provider] for lang in VISTAAR_BENCHMARK.keys()])
-         avg_row.append(f"{avg_score:.2f}%")
-     rows.append(avg_row)
-
-     return [headers] + rows
 
- def create_model_specs_table():
-     """Create model specifications comparison table"""
-     headers = ["Model", "Parameters", "Languages", "Architecture", "License", "Specialty"]
-
-     rows = [
-         ["IndicConformer-600M", "600M", "22 Indian", "Conformer CTC+RNNT", "MIT", "Comprehensive coverage"],
-         ["AudioX-North", "Unknown", "Hindi, Gujarati, Marathi", "Fine-tuned ASR", "Unknown", "North Indian optimization"],
-         ["AudioX-South", "Unknown", "Tamil, Telugu, Kannada, Malayalam", "Fine-tuned ASR", "Unknown", "South Indian optimization"],
-         ["Facebook MMS", "1B", "1400+ Global", "Wav2Vec2", "CC-BY-NC 4.0", "Massive multilingual"]
-     ]
-
-     return [headers] + rows
-
- # Initialize model manager
- model_manager = ASRModelManager()
-
- # Create Gradio interface
- with gr.Blocks(
-     title="🎯 ASR Model Comparison: IndicConformer vs AudioX vs MMS",
-     theme=gr.themes.Soft(),
-     css="""
-     .performance-card {
-         background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-         padding: 1rem;
-         border-radius: 10px;
-         color: white;
-         margin: 0.5rem 0;
-     }
-     .metric-highlight {
-         background: #f0f9ff;
-         padding: 0.5rem;
-         border-left: 4px solid #3b82f6;
-         margin: 0.5rem 0;
-     }
-     """
- ) as demo:
-
-     gr.Markdown("""
-     # 🎯 Comprehensive ASR Model Comparison Dashboard
-
-     Compare three cutting-edge Automatic Speech Recognition models for Indian languages:
-
-     - 🇮🇳 **AI4Bharat IndicConformer-600M**: Complete 22 Indian language coverage
-     - 🎯 **Jivi AI AudioX**: Specialized North/South variants with industry-leading accuracy
-     - 🌍 **Facebook MMS**: Massive 1B parameter multilingual model
-
-     ## 🏆 Key Highlight: AudioX achieves **20.1% average WER** - Best in class performance!
-     """)
-
-     with gr.Tabs():
-
-         # Live Testing Tab
-         with gr.TabItem("🎤 Live ASR Testing"):
-             gr.Markdown("### Upload audio and test model performance in real-time")
-
-             with gr.Row():
-                 with gr.Column(scale=1):
-                     audio_input = gr.Audio(
-                         label="📁 Upload Audio File",
-                         type="filepath",
-                         format="wav"
-                     )
-
-                     model_selector = gr.Dropdown(
-                         choices=list(MODELS_CONFIG.keys()),
-                         label="🤖 Select ASR Model",
-                         value="IndicConformer-600M",
-                         info="Choose the model for transcription"
-                     )
-
-                     reference_input = gr.Textbox(
-                         label="📝 Reference Text (Optional)",
-                         placeholder="Enter the correct transcription for accuracy calculation...",
-                         lines=3,
-                         info="Provide ground truth text to calculate WER and CER"
-                     )
-
-                     transcribe_button = gr.Button(
-                         "🚀 Transcribe Audio",
-                         variant="primary",
-                         size="lg"
-                     )
-
-                 with gr.Column(scale=1):
-                     transcription_output = gr.Textbox(
-                         label="📄 Transcription Result",
-                         lines=5,
-                         max_lines=8
-                     )
-
-                     model_info_output = gr.Textbox(
-                         label="ℹ️ Model Information",
-                         lines=7
-                     )
-
-             with gr.Row():
-                 with gr.Column():
-                     wer_output = gr.Number(
-                         label="📊 Word Error Rate (WER %)",
-                         precision=2,
-                         info="Lower is better"
-                     )
-                 with gr.Column():
-                     cer_output = gr.Number(
-                         label="📊 Character Error Rate (CER %)",
-                         precision=2,
-                         info="Lower is better"
-                     )
-                 with gr.Column():
-                     rtf_output = gr.Number(
-                         label="⚡ Real-Time Factor (RTF)",
-                         precision=3,
-                         info="< 1.0 = faster than real-time"
-                     )
-
-         # Benchmark Results Tab
-         with gr.TabItem("📊 Vistaar Benchmark Results"):
-             gr.Markdown("""
-             ## 🏆 Official Vistaar Benchmark Comparison (WER %)
-
-             Performance evaluation on AI4Bharat's standardized Vistaar benchmark across 7 Indian languages.
-             **Lower WER indicates better accuracy** ⬇️
-             """)
-
-             benchmark_df = gr.Dataframe(
-                 value=create_benchmark_table(),
-                 label="📈 Word Error Rate Comparison",
-                 interactive=False,
-                 wrap=True
-             )
-
-             gr.Markdown("""
-             ### 🎯 Key Performance Insights:
-
-             | 🏅 Rank | Model | Avg WER | Strength |
-             |---------|-------|---------|----------|
-             | 🥇 1st | **AudioX** | **20.1%** | Consistently best across languages |
-             | 🥈 2nd | ElevenLabs Scribe-v1 | 20.6% | Strong competitor, especially in Gujarati |
-             | 🥉 3rd | Sarvam saarika:v2 | 22.3% | Solid performance across the board |
-             | 4th | AI4Bharat IndicWhisper | 22.8% | Good baseline for comparison |
-             | 5th | Microsoft Azure STT | 30.0% | Commercial solution performance |
-
-             ### 💡 Analysis:
-             - **AudioX dominates** in 5 out of 7 languages
-             - **Specialized models outperform** general commercial solutions
-             - **Malayalam and Telugu** are the most challenging languages across all models
-             - **Hindi** shows the best performance across all models
-             """)
-
-         # Model Architecture Tab
-         with gr.TabItem("⚙️ Model Architecture & Specs"):
-             gr.Markdown("## 🔧 Technical Specifications Comparison")
-
-             specs_df = gr.Dataframe(
-                 value=create_model_specs_table(),
-                 label="📋 Model Architecture Details",
-                 interactive=False
-             )
-
-             with gr.Row():
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🎯 IndicConformer-600M
-
-                     **🏗️ Architecture**: Hybrid CTC + RNNT Conformer
-                     **🎯 Focus**: Comprehensive Indian language coverage
-                     **📊 Training**: Large-scale multilingual approach
-                     **⚡ Inference**: Dual decoding strategies
-                     **🎭 Use Cases**:
-                     - General-purpose Indian ASR
-                     - Research and development
-                     - Educational applications
-
-                     **✅ Strengths**:
-                     - Open-source MIT license
-                     - Covers all 22 official languages
-                     - Well-documented and accessible
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🏆 AudioX Series
-
-                     **🏗️ Architecture**: Specialized fine-tuned models
-                     **🎯 Focus**: Language-specific optimization
-                     **📊 Training**: Open-source + proprietary medical data
-                     **⚡ Inference**: Optimized for production
-                     **🎭 Use Cases**:
-                     - Production voice assistants
-                     - Healthcare transcription
-                     - Customer service automation
-                     - Content creation platforms
-
-                     **✅ Strengths**:
-                     - Industry-leading accuracy
-                     - Regional accent handling
-                     - Robust to noise and variations
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🌍 Facebook MMS
-
-                     **🏗️ Architecture**: Wav2Vec2 self-supervised
-                     **🎯 Focus**: Massive multilingual coverage
-                     **📊 Training**: 500K hours, 1400+ languages
-                     **⚡ Inference**: Requires task-specific fine-tuning
-                     **🎭 Use Cases**:
-                     - Research in multilingual ASR
-                     - Low-resource language support
-                     - Cross-lingual applications
-                     - Base model for fine-tuning
-
-                     **✅ Strengths**:
-                     - Unprecedented language coverage
-                     - Strong foundation model
-                     - Excellent for rare languages
-                     """)
-
-         # Performance Analysis Tab
-         with gr.TabItem("📈 Performance Deep Dive"):
-             gr.Markdown("""
-             # 🔍 Detailed Performance Analysis
-
-             ## 📊 Understanding ASR Metrics
-             """)
-
-             with gr.Row():
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 📉 Word Error Rate (WER)
-
-                     **Formula**: `(S + D + I) / N × 100%`
-                     - **S**: Substitutions
-                     - **D**: Deletions
-                     - **I**: Insertions
-                     - **N**: Total words in reference
-
-                     **Interpretation**:
-                     - **< 5%**: Excellent
-                     - **5-15%**: Good
-                     - **15-30%**: Fair
-                     - **> 30%**: Poor
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🔤 Character Error Rate (CER)
-
-                     **Formula**: Same as WER but at character level
-
-                     **Why CER matters**:
-                     - Better for morphologically rich languages
-                     - Captures partial word recognition
-                     - Useful for downstream NLP tasks
-                     - More granular error analysis
-
-                     **Typical Range**: Usually lower than WER
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### ⚡ Real-Time Factor (RTF)
-
-                     **Formula**: `Processing Time / Audio Duration`
-
-                     **Interpretation**:
-                     - **RTF < 1.0**: ⚡ Faster than real-time
-                     - **RTF = 1.0**: 🎯 Real-time processing
-                     - **RTF > 1.0**: 🐌 Slower than real-time
-
-                     **Production Requirements**:
-                     - Live applications: RTF < 0.3
-                     - Batch processing: RTF < 1.0 acceptable
-                     """)
-
-             gr.Markdown("""
-             ## 🏆 Language-Specific Performance Champions
-
-             | Language | 🥇 Best Model | WER Score | 🎯 Insights |
-             |----------|-------------|-----------|-----------|
-             | **Hindi** | AudioX | 12.14% | Strongest performance, most data available |
-             | **Gujarati** | ElevenLabs | 17.96% | Close race with AudioX (18.66%) |
-             | **Marathi** | ElevenLabs | 16.51% | Competitive performance across models |
-             | **Tamil** | AudioX | 21.79% | Dravidian language complexity handled well |
-             | **Telugu** | AudioX | 24.63% | Challenging agglutinative morphology |
-             | **Kannada** | AudioX | 17.61% | Consistent South Indian performance |
-             | **Malayalam** | AudioX | 26.92% | Most challenging across all models |
-
-             ### 🔍 Key Observations:
-
-             1. **AudioX Dominance**: Wins in 5 out of 7 languages
-             2. **Language Difficulty**: Malayalam > Telugu > Tamil (Dravidian complexity)
-             3. **Commercial Gap**: 10-15% WER difference vs specialized models
-             4. **Regional Patterns**: North Indian languages generally perform better
-             5. **Model Specialization**: Purpose-built models significantly outperform generic ones
-             """)
-
-         # Usage Guidelines Tab
-         with gr.TabItem("📖 Usage Guidelines"):
-             gr.Markdown("""
-             # 🚀 Model Selection Guide
-
-             ## 🎯 Which Model Should You Choose?
-             """)
-
-             with gr.Row():
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🏆 Choose AudioX When:
-
-                     ✅ **Production Applications**
-                     ✅ **Highest Accuracy Requirements**
-                     ✅ **North/South Indian Languages**
-                     ✅ **Real-time Processing**
-                     ✅ **Commercial Deployment**
-                     ✅ **Healthcare/Medical Domain**
-
-                     **Best For**: Voice assistants, transcription services, customer support
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🎓 Choose IndicConformer When:
-
-                     ✅ **Research & Development**
-                     ✅ **Open Source Requirements**
-                     ✅ **All 22 Indian Languages**
-                     ✅ **Educational Projects**
-                     ✅ **Custom Fine-tuning**
-                     ✅ **Experimental Work**
-
-                     **Best For**: Academic research, prototyping, learning
-                     """)
-
-                 with gr.Column():
-                     gr.Markdown("""
-                     ### 🌍 Choose Facebook MMS When:
-
-                     ✅ **Rare/Low-resource Languages**
-                     ✅ **Multilingual Applications**
-                     ✅ **Transfer Learning Base**
-                     ✅ **Research in Multilingual ASR**
-                     ✅ **Cross-lingual Studies**
-                     ✅ **Foundation Model Needs**
-
-                     **Best For**: Research, rare languages, base model
-                     """)
-
-             gr.Markdown("""
-             ## 🛠️ Implementation Tips
-
-             ### 📋 Pre-processing Recommendations:
-             - **Sample Rate**: Ensure 16kHz for all models
-             - **Audio Format**: WAV preferred over compressed formats
-             - **Noise Reduction**: Apply basic denoising for better results
-             - **Normalization**: Audio amplitude normalization recommended
-
-             ### ⚡ Performance Optimization:
-             - **GPU Usage**: Significant speedup with CUDA-enabled devices
-             - **Batch Processing**: Process multiple files together when possible
-             - **Model Caching**: Keep models loaded in memory for repeated use
-             - **Quantization**: Consider model quantization for deployment
-
-             ### 🎯 Accuracy Improvement:
-             - **Domain Adaptation**: Fine-tune on domain-specific data when possible
-             - **Language Models**: Integrate external LMs for better word-level accuracy
-             - **Post-processing**: Apply spelling correction and grammar checking
-             - **Ensemble Methods**: Combine multiple models for critical applications
-             """)
-
-     # Event handlers
-     transcribe_button.click(
          fn=transcribe_audio,
-         inputs=[audio_input, model_selector, reference_input],
-         outputs=[transcription_output, model_info_output, wer_output, cer_output, rtf_output],
-         show_progress=True
      )
 
- # Launch configuration
  if __name__ == "__main__":
-     demo.launch(
-         share=True,
-         server_name="0.0.0.0",
-         server_port=7860,
-         show_error=True
-     )
 
  import gradio as gr
  import torch
  import torchaudio
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoModelForCTC
  import librosa
+ import numpy as np
+ from jiwer import wer, cer
  import time
 
  # Model configurations
+ MODEL_CONFIGS = {
+     "AudioX-North (Jivi AI)": {
+         "repo": "jiviai/audioX-north-v1",
+         "model_type": "seq2seq",
+         "description": "Supports Hindi, Gujarati, Marathi"
      },
+     "IndicConformer (AI4Bharat)": {
+         "repo": "ai4bharat/indic-conformer-600m-multilingual",
+         "model_type": "ctc_rnnt",
+         "description": "Supports 22 Indian languages"
      },
+     "MMS (Facebook)": {
+         "repo": "facebook/mms-1b",
+         "model_type": "ctc",
+         "description": "Supports over 1,400 languages (fine-tuning recommended)"
      }
  }
 
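+ # Note: facebook/mms-1b is the base pretrained checkpoint; for out-of-the-box
+ # transcription, a fine-tuned variant such as facebook/mms-1b-all (the repo used
+ # by the previous version of this app) is likely the safer choice.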
+ # Load model and processor
+ def load_model_and_processor(model_name):
+     config = MODEL_CONFIGS[model_name]
+     repo = config["repo"]
+     model_type = config["model_type"]
+
      try:
+         processor = AutoProcessor.from_pretrained(repo)
+         if model_type == "seq2seq":
+             model = AutoModelForSpeechSeq2Seq.from_pretrained(repo)
+         else:  # ctc or ctc_rnnt
+             model = AutoModelForCTC.from_pretrained(repo)
+         return model, processor, model_type
      except Exception as e:
+         return None, None, f"Error loading model: {str(e)}"
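+ # Caveat (assumption, not verified here): ai4bharat/indic-conformer-600m-multilingual
+ # ships a custom architecture and may require loading with trust_remote_code=True
+ # rather than the plain AutoModelForCTC path above.
+ # Usage sketch: load errors surface through the third return slot, e.g.
+ #     model, processor, model_type = load_model_and_processor("MMS (Facebook)")
+ #     if model is None: print(model_type)  # model_type then holds the error message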
 
+ # Compute metrics (WER and CER); RTF is measured by the caller from processing time
+ def compute_metrics(reference, hypothesis, audio_duration):
+     if not reference or not hypothesis:
+         return None, None, None
      try:
+         wer_score = wer(reference, hypothesis)
+         cer_score = cer(reference, hypothesis)
+         return wer_score, cer_score, None  # third slot carries an error message on failure
+     except Exception as e:
+         return None, None, f"Error computing metrics: {str(e)}"
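+ # Worked example (hypothetical strings): wer("hello world", "hello word") == 0.5,
+ # i.e. one substituted word out of two reference words. cer applies the same
+ # edit-distance ratio at the character level (here 1/11), so it is usually lower.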
 
+ def transcribe_audio(audio_file, model_name, reference_text=""):
+     if not audio_file:
+         return "Please upload an audio file.", None, None, None
 
+     # Load model and processor
+     model, processor, model_type = load_model_and_processor(model_name)
+     if isinstance(model_type, str) and model_type.startswith("Error"):
+         return model_type, None, None, None
 
      try:
+         # Load and preprocess audio
+         audio, sr = librosa.load(audio_file, sr=16000)
          audio_duration = len(audio) / sr
 
+         # Process audio (Whisper-style processors return "input_features";
+         # CTC processors such as Wav2Vec2 return "input_values")
+         inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
+         input_features = inputs.get("input_features", inputs.get("input_values"))
 
+         # Measure processing time for RTF
+         start_time = time.time()
+         with torch.no_grad():
+             if model_type == "seq2seq":
+                 outputs = model.generate(input_features)
+             else:  # ctc or ctc_rnnt
+                 outputs = model(input_features).logits
+                 outputs = torch.argmax(outputs, dim=-1)
+
+         # Decode transcription
+         transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+
+         # RTF does not depend on the reference text, so compute it unconditionally
+         rtf = (time.time() - start_time) / audio_duration
+
+         # Compute accuracy metrics only if reference text is provided
+         wer_score, cer_score = None, None
+         if reference_text:
+             wer_score, cer_score, metric_error = compute_metrics(reference_text, transcription, audio_duration)
+             if isinstance(metric_error, str):
+                 return transcription, wer_score, cer_score, metric_error
+
+         return transcription, wer_score, cer_score, rtf
      except Exception as e:
+         return f"Error during transcription: {str(e)}", None, None, None
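+ # RTF example: 2 s of processing for an 8 s clip gives RTF = 0.25;
+ # RTF < 1.0 means faster than real-time.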
 
+ # Gradio interface
+ def create_interface():
+     model_choices = list(MODEL_CONFIGS.keys())
+     return gr.Interface(
          fn=transcribe_audio,
+         inputs=[
+             gr.Audio(type="filepath", label="Upload Audio File (16kHz recommended)"),
+             gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0]),
+             gr.Textbox(label="Reference Text (Optional for WER/CER)", placeholder="Enter ground truth text here")
+         ],
+         outputs=[
+             gr.Textbox(label="Transcription"),
+             gr.Textbox(label="WER"),
+             gr.Textbox(label="CER"),
+             gr.Textbox(label="RTF")
+         ],
+         title="Multilingual Speech-to-Text with Metrics",
+         description="Upload an audio file, select a model, and optionally provide reference text to compute WER, CER, and RTF.",
+         allow_flagging="never"
      )
 
  if __name__ == "__main__":
+     iface = create_interface()
+     iface.launch()
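+     # Note: the previous version launched with share=True, server_name="0.0.0.0",
+     # server_port=7860; launch() without arguments falls back to Gradio defaults
+     # (a local-only server on port 7860).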