from typing import Any, Dict

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import WavLMForSequenceClassification


def feature_extract_simple(
    wav,
    sr=16_000,
    win_len=15.0,
    win_stride=15.0,
    do_normalize=False,
) -> np.ndarray:
"""Simple feature extraction for WavLM. | |
Parameters | |
---------- | |
wav : str or array-like | |
path to the wav file, or array-like | |
sr : int, optional | |
sample rate, by default 16_000 | |
win_len : float, optional | |
window length, by default 15.0 | |
win_stride : float, optional | |
window stride, by default 15.0 | |
do_normalize: bool, optional | |
whether to normalize the input, by default False. | |
Returns | |
------- | |
np.ndarray | |
batched input to WavLM | |
""" | |
    if isinstance(wav, str):
        signal, _ = librosa.core.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError(f"Could not convert input to a 1-D array: {e}") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        for i in range(0, len(signal), stride):
            is_last_chunk = i + win_samples > len(signal)
            if is_last_chunk:
                # Zero-pad the last chunk so it matches the other chunk lengths
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if is_last_chunk:
                break
    else:
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)  # [N, T]
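

# Worked example (illustrative numbers, not from the original source): a 40 s
# mono clip at 16 kHz has 640,000 samples. With win_len = win_stride = 15.0,
# chunks start at samples 0, 240,000, and 480,000; the final 160,000-sample
# tail is zero-padded to 240,000 samples, so the call below would return an
# array of shape (3, 240000):
#
#     batch = feature_extract_simple(np.zeros(640_000), sr=16_000)
#     assert batch.shape == (3, 240_000)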


def infer(model, inputs) -> torch.Tensor:
    # Run the classifier without tracking gradients; `output.logits` is
    # already a tensor, so sigmoid can be applied to it directly.
    with torch.no_grad():
        output = model(inputs)
    probs = torch.sigmoid(output.logits)
    return probs
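

# Shape sketch (an assumption read off the calls above, not stated in the
# source): `inputs` is a float tensor of shape [N, T] produced by
# feature_extract_simple, and the returned probabilities have shape
# [N, num_labels], with each label scored independently in [0, 1]
# (multi-label sigmoid, so a row need not sum to 1).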


def predict(audio_file) -> Dict[str, Any]:
    if audio_file is None:
        return {"No prediction available": 0.0}
    try:
        input_np = feature_extract_simple(audio_file, sr=16_000, do_normalize=True)
        input_pt = torch.from_numpy(input_np).float()
        probs = infer(model, input_pt)
        probs_list = probs.reshape(-1, len(labels)).tolist()
        # Build a label -> score dictionary from the first segment
        if len(probs_list) > 0:
            first_segment_probs = probs_list[0]
            results = {
                label: float(prob) for label, prob in zip(labels, first_segment_probs)
            }
            # If there are multiple segments, say so in the results
            if len(probs_list) > 1:
                results["Note"] = (
                    f"Audio contains {len(probs_list)} segments. "
                    "Showing first segment only."
                )
        else:
            results = {"Error": "No segments detected in audio"}
        # Sort the numeric scores by confidence. The optional "Note" entry is
        # a string, so it is pulled out before sorting (mixing it into the
        # sort would raise a TypeError when a str is compared with a float)
        # and re-attached at the end.
        note = results.pop("Note", None)
        sorted_results = dict(sorted(results.items(), key=lambda x: x[1], reverse=True))
        if note is not None:
            sorted_results["Note"] = note
        return sorted_results
    except Exception as e:
        return {"Error": str(e)}
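

# Illustrative return value (hypothetical scores, not real model output):
#
#     predict("clip.wav")
#     # -> {"Sexual": 0.87, "Profanity": 0.42, "Harassment": 0.05, ...}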


if __name__ == "__main__":
    model_path = "Roblox/voice-safety-classifier-v2"
    labels = [
        "Discrimination",
        "Harassment",
        "Sexual",
        "IllegalAndRegulated",
        "DatingAndRomantic",
        "Profanity",
    ]
    model = WavLMForSequenceClassification.from_pretrained(
        model_path, num_labels=len(labels)
    )
    model.eval()
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Audio(type="filepath", label="Upload or record audio"),
        outputs=gr.Label(num_top_classes=6, label="Classification Results"),
        title="Voice Safety Classifier",
        description="""This app uses the Roblox Voice Safety Classifier v2 model to identify potentially unsafe content in audio.
        Upload or record an audio file to get started. The model classifies audio into categories including Discrimination,
        Harassment, Sexual, IllegalAndRegulated, DatingAndRomantic, and Profanity.
        The model processes audio in 15-second chunks and returns probability scores for each category.""",
        examples=[],
        flagging_mode="never",
    )
    demo.launch()
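
# To run locally (assuming this file is saved as app.py and gradio, librosa,
# torch, and transformers are installed): `python app.py`. Gradio serves the
# demo on http://127.0.0.1:7860 by default.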