import gradio as gr
import librosa
import librosa.display  # required for librosa.display.specshow below
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
# Load the pretrained Wav2Vec 2.0 ASR model and the locally fine-tuned audio
# classifier. Each load is wrapped in try/except so a failure is reported at
# startup; note that if a load fails, the corresponding names stay undefined
# and predict_voice below will raise a NameError when called.
try:
    wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
except Exception as e:
    print(f"Error loading Wav2Vec 2.0 models: {e}")

try:
    model = AutoModelForAudioClassification.from_pretrained("./")
    feature_extractor = ASTFeatureExtractor.from_pretrained("./")
except Exception as e:
    print(f"Error loading custom models: {e}")
def plot_waveform(waveform, sr):
    """Plot the time-domain waveform and save it to a temporary PNG; return the path."""
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def plot_spectrogram(waveform, sr):
    """Render a mel spectrogram (in dB) and save it to a temporary PNG; return the path."""
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    """Extract AST input features from raw audio.

    Note: ASTFeatureExtractor pads or truncates the spectrogram to its configured
    max_length internally, so the padding/max_length keyword arguments passed here
    are likely ignored; with the default max_length of 1024 the result matches
    target_length anyway."""
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
def apply_time_shift(waveform, max_shift_fraction=0.1):
    """Circularly shift the waveform by a random offset of up to ±10% of its length."""
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
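
# Test-time augmentation: predict_voice below averages the classifier logits over
# the original waveform and one randomly time-shifted copy, which reduces
# sensitivity to where the signal happens to sit in the analysis window.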
def transcribe_audio(audio_file_path):
    """Transcribe speech with Wav2Vec 2.0 using greedy CTC decoding."""
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    # Greedy decoding: take the argmax token per frame; batch_decode collapses
    # repeats and blanks and returns a list of strings
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription
def predict_voice(audio_file_path):
    """Classify the uploaded audio as real or AI-cloned; return the prediction text,
    waveform plot, spectrogram plot, and transcription."""
    try:
        transcription = transcribe_audio(audio_file_path)
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        # Average the logits from the original and time-shifted inputs
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        # Map the model's training labels to user-facing names
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # batch_decode returns a list; take its single string
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""
# Build the Gradio interface
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
    waveform_output = gr.Image(label="Waveform")
    spectrogram_output = gr.Image(label="Spectrogram")
    detect_button = gr.Button("Detect Voice Clone")
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )
# Launch the interface
demo.launch()