import gradio as gr import torch import torchaudio import numpy as np import os import pandas as pd from datetime import timedelta from pathlib import Path from transformers import ( Wav2Vec2ForCTC, Wav2Vec2Processor, WhisperProcessor, WhisperForConditionalGeneration ) from pyannote.audio import Pipeline, Model, Inference from scipy.spatial.distance import cdist import torchaudio.transforms as T def ensure_16k(waveform, sr, target_sr=16000): """Ensure waveform is at 16kHz (mono).""" if waveform.ndim > 1 and waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) # convert stereo → mono if sr != target_sr: resampler = T.Resample(sr, target_sr) waveform = resampler(waveform) return waveform, target_sr # ------------------- Config ------------------- CACHE_DIR = "./models_cache" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"[INFO] Using device: {DEVICE}") HF_TOKEN = os.getenv("HF_TOKEN") # STT model options MODEL_OPTIONS = { "v2 ( 15 hours training data)": "ganga4364/kr_wav2vec2_v2.137000", "v1 ( 9 hours training data) ": "ganga4364/mms_300_khentse_Rinpoche-Checkpoint-58000", "base": "openpecha/general_stt_base_model" } # Cache for STT models stt_cache = {} def load_stt_model(choice): if choice not in stt_cache: print(f"[INFO] Loading STT model: {choice}") model_name = MODEL_OPTIONS[choice] model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=CACHE_DIR).to(DEVICE) processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=CACHE_DIR) model.eval() stt_cache[choice] = (model, processor) return stt_cache[choice] # ------------------- Whisper Large v3 ------------------- print("[INFO] Loading Whisper Large V3 for other speakers...") whisper_model = WhisperForConditionalGeneration.from_pretrained( "openai/whisper-large-v3", cache_dir=CACHE_DIR ).to(DEVICE) whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3", cache_dir=CACHE_DIR) whisper_model.eval() def transcribe_with_whisper(waveform, sr): waveform, sr = ensure_16k(waveform, sr) # Ensure the waveform is long enough for the model if waveform.shape[1] < 400: return "" inputs = whisper_processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt") input_features = inputs.input_features.to(DEVICE) with torch.no_grad(): predicted_ids = whisper_model.generate(input_features) return whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] # ------------------- Pyannote ------------------- try: diarization_pipeline = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.1", token=HF_TOKEN, cache_dir=CACHE_DIR ).to(DEVICE) print("Pyannote diarization loaded") except Exception as e: diarization_pipeline = None print(f"[WARN] Pyannote diarization not available: {e}") # Embedding model for voice print embedding_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM", cache_dir=CACHE_DIR) embedding_inference = Inference(embedding_model, window="whole") # ------------------- Helpers ------------------- MAX_SEGMENT_SEC = 15 def format_timestamp(seconds, format_type="srt"): td = timedelta(seconds=seconds) hours, remainder = divmod(td.seconds, 3600) minutes, seconds = divmod(remainder, 60) milliseconds = round(td.microseconds / 1000) if format_type == "srt": return f"{hours:02d}:{minutes:02d}:{seconds:02d},{milliseconds:03d}" else: return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{milliseconds:03d}" def create_subtitle_file(timestamps_with_text, output_path, format_type="srt"): with open(output_path, "w", encoding="utf-8") as f: if format_type == "vtt": f.write("WEBVTT\n\n") for i, (start, end, text, speaker) in enumerate(timestamps_with_text, 1): if format_type == "srt": f.write(f"{i}\n") f.write(f"{format_timestamp(start)} --> {format_timestamp(end)}\n") f.write(f"{speaker}: {text}\n\n") else: f.write(f"{format_timestamp(start, 'vtt')} --> {format_timestamp(end, 'vtt')}\n") f.write(f"{text}\n\n") def split_long_segment(start, end, max_length=MAX_SEGMENT_SEC): segments = [] total_duration = end - start if total_duration <= max_length: return [(start, end)] current = start while current < end: seg_end = min(current + max_length, end) segments.append((current, seg_end)) current = seg_end return segments def transcribe_segment(waveform, sr, model, processor): waveform, sr = ensure_16k(waveform, sr) # Ensure the waveform is long enough for the model if waveform.shape[1] < 400: # Heuristic value, might need adjustment return "" inputs = processor(waveform.squeeze(), sampling_rate=sr, return_tensors="pt", padding=True) inputs = {k: v.to(DEVICE) for k, v in inputs.items()} with torch.no_grad(): logits = model(**inputs).logits predicted_ids = torch.argmax(logits, dim=-1) return processor.decode(predicted_ids[0].cpu()) # ------------------- Speaker Identification ------------------- def identify_speaker(diarization_df, audio_path, voice_print_embedding, speaker_name, inference, threshold=0.6, n_segments=3): waveform, sr = torchaudio.load(audio_path) speaker_distances = {} for speaker in diarization_df['speaker'].unique(): sp_df = diarization_df[diarization_df['speaker'] == speaker].copy() sp_df['duration'] = sp_df['end'] - sp_df['start'] sp_df = sp_df.sort_values(by='duration', ascending=False).head(n_segments) distances = [] for _, row in sp_df.iterrows(): start, end = int(row['start']*sr), int(row['end']*sr) segment = waveform[:, start:end] seg_path = f"/tmp/{speaker}_{start}_{end}.wav" torchaudio.save(seg_path, segment, sr) try: seg_embedding = inference(seg_path) seg_embedding = np.atleast_2d(seg_embedding) dist = cdist(seg_embedding, voice_print_embedding, metric="cosine")[0, 0] distances.append(dist) except Exception as e: print(f"Error embedding segment {speaker} {row['start']}-{row['end']}: {e}") if distances: speaker_distances[speaker] = np.mean(distances) if not speaker_distances: return None, {}, diarization_df best_match = min(speaker_distances, key=speaker_distances.get) min_distance = speaker_distances[best_match] if min_distance <= threshold: mapping = {sp: speaker_name if sp == best_match else f"Other Speaker {i}" for i, sp in enumerate(speaker_distances.keys())} else: mapping = {sp: f"Speaker {i}" for i, sp in enumerate(speaker_distances.keys())} diarization_df['identified_speaker'] = diarization_df['speaker'].map(mapping) return best_match, mapping, diarization_df # ------------------- Main ------------------- def process_audio(model_choice, mode, voice_print_path, audio_path, speaker_name, threshold=0.6): # --- Full audio --- waveform, sample_rate = torchaudio.load(audio_path) waveform, sample_rate = ensure_16k(waveform, sample_rate) stt_model, stt_processor = load_stt_model(model_choice) # --- Voice print --- vp_waveform, vp_sr = torchaudio.load(voice_print_path) vp_waveform, vp_sr = ensure_16k(vp_waveform, vp_sr) # Save temp 16k voice print file for embedding tmp_vp = "/tmp/voice_print_16k.wav" torchaudio.save(tmp_vp, vp_waveform, vp_sr) voice_print_embedding = embedding_inference(tmp_vp) voice_print_embedding = np.atleast_2d(voice_print_embedding) results = [] if "Diarization" in mode: if diarization_pipeline is None: return "Pyannote diarization is not available.", None, None diarization = diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) # Run diarization - pass audio file path directly for better compatibility #diarization = diarization_pipeline(audio_path) # Correct API for pyannote 3.1+ with DiarizeOutput data = [] # Check if we have the new API (DiarizeOutput with speaker_diarization attribute) if hasattr(diarization, 'speaker_diarization'): # New API (pyannote 3.1+) - iterate over speaker_diarization for turn, speaker in diarization.speaker_diarization: data.append({ "start": turn.start, "end": turn.end, "speaker": speaker }) elif hasattr(diarization, 'itertracks'): # Old API (pyannote < 3.1) - Annotation object for segment, track, speaker in diarization.itertracks(yield_label=True): data.append({ "start": segment.start, "end": segment.end, "speaker": speaker }) else: return "Unsupported pyannote.audio version. Please check the diarization output format.", None, None if not data: return "No speaker segments found in diarization.", None, None diarization_df = pd.DataFrame(data) # Always identify the target speaker _, mapping, diarization_df = identify_speaker( diarization_df, audio_path, voice_print_embedding, speaker_name, embedding_inference, threshold ) for _, row in diarization_df.iterrows(): for seg_start, seg_end in split_long_segment(row['start'], row['end']): seg_waveform = waveform[:, int(seg_start*sample_rate):int(seg_end*sample_rate)] if row['identified_speaker'] == speaker_name: # 🎯 Target speaker → use MMS model transcription = transcribe_segment(seg_waveform, sample_rate, stt_model, stt_processor) else: if mode == "Diarization (Target Speaker Only)": transcription = "" # skip other speakers else: # 👥 Other speakers → Whisper Large v3 transcription = transcribe_with_whisper(seg_waveform, sample_rate) results.append((seg_start, seg_end, transcription, row['identified_speaker'])) # Save subtitle files base_path = os.path.splitext(audio_path)[0] srt_path = f"{base_path}_identified.srt" vtt_path = f"{base_path}_identified.vtt" create_subtitle_file(results, srt_path, "srt") create_subtitle_file(results, vtt_path, "vtt") transcript_text = "\n".join([f"{sp}: {txt}" for (_, _, txt, sp) in results]) return transcript_text, srt_path, vtt_path # ------------------- Gradio ------------------- demo = gr.Interface( fn=process_audio, inputs=[ gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="base", label="Select STT Model"), gr.Radio( choices=["Diarization (Transcribe All)", "Diarization (Target Speaker Only)"], value="Diarization (Transcribe All)", label="Segmentation Method" ), gr.Audio(sources=["upload"], type="filepath", label="Voice Print Audio"), gr.Audio(sources=["upload"], type="filepath", label="Full Audio"), gr.Textbox(value="DKR", label="Speaker Name for Voice Print") ], outputs=[ gr.Textbox(label="Transcript"), gr.File(label="SRT File"), gr.File(label="WebVTT File") ], title="STT + Speaker Identification", description="Choose model, diarization mode, upload voice print and full audio, and label the known speaker. Target speaker → chosen model; other speakers → Whisper Large v3." ) if __name__ == "__main__": demo.launch(share=True)