Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,38 +10,23 @@ text_model = AutoModelForSequenceClassification.from_pretrained("tae898/emoberta
|
|
| 10 |
text_model.gradient_checkpointing_enable()
|
| 11 |
|
| 12 |
|
| 13 |
-
def
|
| 14 |
-
|
| 15 |
-
if audio_file is not None:
|
| 16 |
-
waveform, sr = torchaudio.load(audio_file)
|
| 17 |
-
preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
}
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
|
| 28 |
with torch.no_grad():
|
| 29 |
outputs = text_model(**inputs)
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
return {
|
| 35 |
-
"Detected Emotion": text_model.config.id2label[label_id],
|
| 36 |
-
"Top Predictions": {
|
| 37 |
-
text_model.config.id2label[i]: round(p, 3)
|
| 38 |
-
for i, p in enumerate(probs[0].tolist())
|
| 39 |
-
},
|
| 40 |
-
"Source": "Text"
|
| 41 |
-
}
|
| 42 |
-
|
| 43 |
-
return {"Error": "Please provide audio or text input."}
|
| 44 |
-
|
| 45 |
# Building the UI
|
| 46 |
gradio_ui = gr.Interface(
|
| 47 |
fn=gradio_combined,
|
|
|
|
| 10 |
text_model.gradient_checkpointing_enable()
|
| 11 |
|
| 12 |
|
| 13 |
+
def predict_emotion(audio, text):
|
| 14 |
+
results = {}
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
if audio is not None:
|
| 17 |
+
waveform, sr = torchaudio.load(audio)
|
| 18 |
+
preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
|
| 19 |
+
results["audio_emotion"] = preds[0]["label"]
|
|
|
|
| 20 |
|
| 21 |
+
if text is not None and text.strip() != "":
|
| 22 |
+
inputs = text_tokenizer(text, return_tensors="pt")
|
|
|
|
| 23 |
with torch.no_grad():
|
| 24 |
outputs = text_model(**inputs)
|
| 25 |
+
emotion = text_model.config.id2label[torch.argmax(outputs.logits)]
|
| 26 |
+
results["text_emotion"] = emotion
|
| 27 |
|
| 28 |
+
return results
|
| 29 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Building the UI
|
| 31 |
gradio_ui = gr.Interface(
|
| 32 |
fn=gradio_combined,
|