Create app.py
app.py
ADDED
import gradio as gr
import librosa
import librosa.display  # explicit import so librosa.display.specshow works on older librosa releases
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForAudioClassification, ASTFeatureExtractor
import random
import tempfile

# Attempt to load the models, wrapping each load in try/except so failures are logged instead of crashing at import
try:
    wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
except Exception as e:
    print(f"Error loading Wav2Vec 2.0 models: {e}")

try:
    # Fine-tuned audio classifier and its AST feature extractor, loaded from the Space's repository root
    model = AutoModelForAudioClassification.from_pretrained("./")
    feature_extractor = ASTFeatureExtractor.from_pretrained("./")
except Exception as e:
    print(f"Error loading custom models: {e}")

def plot_waveform(waveform, sr):
    # Render the raw waveform and save it to a temporary PNG for the Gradio Image output
    plt.figure(figsize=(12, 4))
    plt.title('Waveform')
    plt.ylabel('Amplitude')
    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
    plt.xlabel('Time (s)')
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def plot_spectrogram(waveform, sr):
    # Render a mel spectrogram (in dB) and save it to a temporary PNG for the Gradio Image output
    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
    S_DB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(12, 6))
    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
    plt.title('Mel Spectrogram')
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
    plt.savefig(temp_file.name)
    plt.close()
    return temp_file.name

def custom_feature_extraction(audio, sr=16000, target_length=1024):
    # Convert raw audio into the spectrogram features expected by the AST classifier
    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
    return features.input_values

def apply_time_shift(waveform, max_shift_fraction=0.1):
    # Randomly roll the waveform by up to +/-10% of its length (simple augmentation)
    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
    return np.roll(waveform, shift)

def transcribe_audio(audio_file_path):
    # Transcribe speech with Wav2Vec 2.0 (greedy CTC decoding)
    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = wav2vec_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec_processor.batch_decode(predicted_ids)
    return transcription

def predict_voice(audio_file_path):
    try:
        transcription = transcribe_audio(audio_file_path)

        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
        augmented_waveform = apply_time_shift(waveform)

        original_features = custom_feature_extraction(waveform, sr=sample_rate)
        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)

        # Classify both the original and the time-shifted copy, then average the logits (light test-time augmentation)
        with torch.no_grad():
            outputs_original = model(original_features)
            outputs_augmented = model(augmented_features)

        logits = (outputs_original.logits + outputs_augmented.logits) / 2
        predicted_index = logits.argmax()
        original_label = model.config.id2label[predicted_index.item()]
        confidence = torch.softmax(logits, dim=1).max().item() * 100

        # Map the classifier's label names to user-facing descriptions
        label_mapping = {
            "Spoof": "AI-generated Clone",
            "Bonafide": "Real Human Voice"
        }
        new_label = label_mapping.get(original_label, "Unknown")

        waveform_plot = plot_waveform(waveform, sample_rate)
        spectrogram_plot = plot_spectrogram(waveform, sample_rate)

        return (
            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
            waveform_plot,
            spectrogram_plot,
            transcription[0]  # transcribe_audio returns a list with a single string
        )
    except Exception as e:
        return f"Error during processing: {e}", None, None, ""

with gr.Blocks(css="style.css") as demo:
    gr.Markdown("## Voice Clone Detection")
    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")

    with gr.Row():
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")

    with gr.Row():
        prediction_output = gr.Textbox(label="Prediction")
        transcription_output = gr.Textbox(label="Transcription")
        waveform_output = gr.Image(label="Waveform")
        spectrogram_output = gr.Image(label="Spectrogram")

    detect_button = gr.Button("Detect Voice Clone")
    detect_button.click(
        fn=predict_voice,
        inputs=[audio_input],
        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
    )

# Launch the interface
demo.launch()
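For reference, a rough sense of what custom_feature_extraction returns under the default ASTFeatureExtractor settings (16 kHz audio, 128 mel bins, 1024 frames). The probe below is only a sketch and is not part of app.py: the one-second dummy array is an assumption, and it presumes the definitions above have been executed with the custom model files present.

# Sketch only: probe the AST feature shape with one second of silence at 16 kHz
dummy = np.zeros(16000, dtype=np.float32)
feats = custom_feature_extraction(dummy, sr=16000)
print(feats.shape)  # expected around (1, 1024, 128): batch x time frames x mel bins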