LingoJr committed on
Commit
291ce94
·
verified ·
1 Parent(s): 50e2b80

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -27
app.py CHANGED
@@ -10,38 +10,23 @@ text_model = AutoModelForSequenceClassification.from_pretrained("tae898/emoberta
10
  text_model.gradient_checkpointing_enable()
11
 
12
 
13
def gradio_combined(audio_file, text):
    """Classify emotion from an audio clip or a text string.

    Audio takes priority: when both inputs are supplied, only the audio
    result is returned.

    Args:
        audio_file: Path to an audio file, or None.
        text: Free-form text, or None / empty string.

    Returns:
        dict with keys "Detected Emotion", "Top Predictions", and "Source",
        or {"Error": ...} when neither input is provided.
    """
    # Case 1 — Audio provided
    if audio_file is not None:
        waveform, sr = torchaudio.load(audio_file)
        preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)

        return {
            "Detected Emotion": preds[0]["label"],
            "Top Predictions": {p["label"]: round(p["score"], 3) for p in preds},
            "Source": "Audio"
        }

    # Case 2 — Text provided. Guard against None: Gradio passes None for an
    # untouched textbox, and None.strip() would raise AttributeError.
    if text is not None and text.strip() != "":
        inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = text_model(**inputs)

        probs = torch.nn.functional.softmax(outputs.logits, dim=1)
        label_id = torch.argmax(probs).item()

        return {
            "Detected Emotion": text_model.config.id2label[label_id],
            "Top Predictions": {
                text_model.config.id2label[i]: round(p, 3)
                for i, p in enumerate(probs[0].tolist())
            },
            "Source": "Text"
        }

    return {"Error": "Please provide audio or text input."}
44
-
45
  # Building the UI
46
  gradio_ui = gr.Interface(
47
  fn=gradio_combined,
 
10
  text_model.gradient_checkpointing_enable()
11
 
12
 
13
def predict_emotion(audio, text):
    """Classify emotion from audio and/or text.

    Unlike the single-source variant, both inputs are processed when both
    are supplied.

    Args:
        audio: Path to an audio file, or None.
        text: Free-form text, or None / empty string.

    Returns:
        dict with optional keys "audio_emotion" and "text_emotion";
        empty when neither input is provided.
    """
    results = {}

    if audio is not None:
        waveform, sr = torchaudio.load(audio)
        preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
        results["audio_emotion"] = preds[0]["label"]

    if text is not None and text.strip() != "":
        # truncation=True keeps over-length input within the model's max
        # sequence length instead of raising at inference time.
        inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
        with torch.no_grad():
            outputs = text_model(**inputs)
        # .item() converts the 0-d argmax tensor to a plain int; id2label
        # is keyed by int, so indexing it with a Tensor raises KeyError.
        label_id = torch.argmax(outputs.logits).item()
        results["text_emotion"] = text_model.config.id2label[label_id]

    return results
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Building the UI
31
  gradio_ui = gr.Interface(
32
  fn=gradio_combined,