Kabatubare committed
Commit afcca30 · verified · 1 Parent(s): 1f80919

Create app.py

Files changed (1): app.py (+123, -0)
app.py ADDED
@@ -0,0 +1,123 @@
+import gradio as gr
+import librosa
+import librosa.display  # explicit import; older librosa versions do not expose specshow via `import librosa` alone
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForAudioClassification, ASTFeatureExtractor
+import random
+import tempfile
+
+# Attempt to load models with try-except block to handle errors
+try:
+    wav2vec_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
+    wav2vec_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+except Exception as e:
+    print(f"Error loading Wav2Vec 2.0 models: {e}")
+
+try:
+    model = AutoModelForAudioClassification.from_pretrained("./")
+    feature_extractor = ASTFeatureExtractor.from_pretrained("./")
+except Exception as e:
+    print(f"Error loading custom models: {e}")
+
+def plot_waveform(waveform, sr):
+    plt.figure(figsize=(12, 4))
+    plt.title('Waveform')
+    plt.ylabel('Amplitude')
+    plt.plot(np.linspace(0, len(waveform) / sr, len(waveform)), waveform)
+    plt.xlabel('Time (s)')
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
+    plt.savefig(temp_file.name)
+    plt.close()
+    return temp_file.name
+
+def plot_spectrogram(waveform, sr):
+    S = librosa.feature.melspectrogram(y=waveform, sr=sr, n_mels=128)
+    S_DB = librosa.power_to_db(S, ref=np.max)
+    plt.figure(figsize=(12, 6))
+    librosa.display.specshow(S_DB, sr=sr, x_axis='time', y_axis='mel', cmap='inferno')
+    plt.title('Mel Spectrogram')
+    plt.colorbar(format='%+2.0f dB')
+    plt.tight_layout()
+    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir='./')
+    plt.savefig(temp_file.name)
+    plt.close()
+    return temp_file.name
+
+def custom_feature_extraction(audio, sr=16000, target_length=1024):
+    # Pad to a fixed number of frames so the AST classifier always sees a constant input size
+    features = feature_extractor(audio, sampling_rate=sr, return_tensors="pt", padding="max_length", max_length=target_length)
+    return features.input_values
+
+def apply_time_shift(waveform, max_shift_fraction=0.1):
+    # Circularly shift the signal by a random offset of up to +/- max_shift_fraction of its length
+    shift = random.randint(-int(max_shift_fraction * len(waveform)), int(max_shift_fraction * len(waveform)))
+    return np.roll(waveform, shift)
+
+def transcribe_audio(audio_file_path):
+    waveform, _ = librosa.load(audio_file_path, sr=wav2vec_processor.feature_extractor.sampling_rate, mono=True)
+    input_values = wav2vec_processor(waveform, return_tensors="pt", padding="longest").input_values
+    with torch.no_grad():
+        logits = wav2vec_model(input_values).logits
+    predicted_ids = torch.argmax(logits, dim=-1)
+    transcription = wav2vec_processor.batch_decode(predicted_ids)
+    return transcription
+
+def predict_voice(audio_file_path):
+    try:
+        transcription = transcribe_audio(audio_file_path)
+
+        waveform, sample_rate = librosa.load(audio_file_path, sr=feature_extractor.sampling_rate, mono=True)
+        augmented_waveform = apply_time_shift(waveform)
+
+        original_features = custom_feature_extraction(waveform, sr=sample_rate)
+        augmented_features = custom_feature_extraction(augmented_waveform, sr=sample_rate)
+
+        with torch.no_grad():
+            outputs_original = model(original_features)
+            outputs_augmented = model(augmented_features)
+
+        # Test-time augmentation: average the logits over the original and time-shifted inputs
+        logits = (outputs_original.logits + outputs_augmented.logits) / 2
+        predicted_index = logits.argmax()
+        original_label = model.config.id2label[predicted_index.item()]
+        confidence = torch.softmax(logits, dim=1).max().item() * 100
+
+        label_mapping = {
+            "Spoof": "AI-generated Clone",
+            "Bonafide": "Real Human Voice"
+        }
+        new_label = label_mapping.get(original_label, "Unknown")
+
+        waveform_plot = plot_waveform(waveform, sample_rate)
+        spectrogram_plot = plot_spectrogram(waveform, sample_rate)
+
+        return (
+            f"The voice is classified as '{new_label}' with a confidence of {confidence:.2f}%.",
+            waveform_plot,
+            spectrogram_plot,
+            transcription[0]  # batch_decode returns a list; take its single string
+        )
+    except Exception as e:
+        return f"Error during processing: {e}", None, None, ""
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown("## Voice Clone Detection")
+    gr.Markdown("Detects whether a voice is real or an AI-generated clone. Upload an audio file to see the results.")
+
+    with gr.Row():
+        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
+
+    with gr.Row():
+        prediction_output = gr.Textbox(label="Prediction")
+        transcription_output = gr.Textbox(label="Transcription")
+        waveform_output = gr.Image(label="Waveform")
+        spectrogram_output = gr.Image(label="Spectrogram")
+
+    detect_button = gr.Button("Detect Voice Clone")
+    detect_button.click(
+        fn=predict_voice,
+        inputs=[audio_input],
+        outputs=[prediction_output, waveform_output, spectrogram_output, transcription_output]
+    )
+
+# Launch the interface
+demo.launch()
+
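A note for anyone reproducing this Space locally: the two from_pretrained("./") calls expect the classifier's files (config.json, preprocessor_config.json, and a weights file such as pytorch_model.bin or model.safetensors) to sit next to app.py at the repo root, and gr.Blocks(css="style.css") expects a style.css file there as well. The imports imply roughly the dependency set below; this requirements.txt is a sketch inferred from the code, not a file included in this commit, and the unpinned package names are an assumption:

# requirements.txt — sketch inferred from app.py's imports, not part of this commit
gradio
transformers
torch
librosa
numpy
matplotlib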