import gradio as gr
import librosa
import librosa.display  # explicit import needed for librosa.display.specshow on older librosa versions
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile
# Attempt to load models with try-except blocks to handle errors
try:
    wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
except Exception as e:
    print(f"Error loading Wav2Vec 2.0 models: {e}")

try:
    model = AutoModelForAudioClassification.from_pretrained("./")
    feature_extractor = ASTFeatureExtractor.from_pretrained("./")
except Exception as e:
    print(f"Error loading custom models: {e}")
def plot_waveform(waveform, sr):
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def plot_spectrogram(waveform, sr):
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name
def custom_feature_extraction(audio, sr=16000, target_length=1024):
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values
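# Note: ASTFeatureExtractor turns the raw waveform into a log-mel spectrogram and, with its default
# configuration, pads or truncates it to 1024 frames of 128 mel bins, so input_values is expected to
# have shape (1, 1024, 128); the target_length default here mirrors that 1024-frame setting.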
def apply_time_shift(waveform, max_shift_fraction=0.1):
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)
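# Example: for a 5-second clip at 16 kHz (80,000 samples) and the default max_shift_fraction of 0.1,
# the shift is drawn from [-8000, 8000] samples (about +/-0.5 s); np.roll wraps the shifted samples
# around the ends of the array rather than padding with silence.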
def transcribe_audio(audio_file_path):
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription
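# batch_decode returns one string per item in the batch; with a single clip the result is a
# one-element list, which predict_voice reads as transcription[0].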
def predict_voice(audio_file_path):
    try:
        transcription = transcribe_audio(audio_file_path)
        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)
        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)
        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")
        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # Assuming transcription returns a list with a single string
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""
with gr.Blocks(css="style.css") as demo:
gr.Markdown("## Voice Clone Detection")
gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
with gr.Row():
audio_input = gr.Audio(label="Upload Audio File", type="filepath")
with gr.Row():
prediction_output = gr.Textbox(label="Prediction")
transcription_output = gr.Textbox(label="Transcription")
waveform_output = gr.Image(label="Waveform")
spectrogram_output = gr.Image(label="Spectrogram")
detect_button = gr.Button("Detect Voice Clone")
detect_button.click(
fn=predict_voice,
inputs=[audio_input],
outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
)
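# The order of `outputs` must match predict_voice's return tuple:
# (prediction text, waveform plot, spectrogram plot, transcription).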
# Launch the interface
demo.launch()