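"""Multi-label emotion prediction with the logasanjeev/emotions-analyzer-bert ONNX model.

Loads the tokenizer, ONNX session, emotion labels, and per-label thresholds from the
Hugging Face Hub, preprocesses Reddit-style text, and prints the emotions whose
sigmoid probabilities clear their thresholds. Usable from the command line or as a module.
"""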
import numpy as np
import json
from huggingface_hub import hf_hub_download
import re
import emoji
from transformers import BertTokenizer
import onnxruntime as ort

def preprocess_text(text):
    """Preprocess the input text to match training conditions."""
    text = re.sub(r'u/\w+', '[USER]', text)             # Reddit user mentions -> placeholder
    text = re.sub(r'r/\w+', '[SUBREDDIT]', text)        # subreddit mentions -> placeholder
    text = re.sub(r'http[s]?://\S+', '[URL]', text)     # URLs -> placeholder
    text = emoji.demojize(text, delimiters=(" ", " "))  # emojis -> space-separated names
    text = text.lower()                                 # lowercase everything, placeholders included
    return text
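
# Illustrative example (not actual model I/O; the exact emoji name depends on the
# installed version of the emoji package):
#     preprocess_text("Thanks u/JaneDoe, r/python rocks! 😀 https://example.com")
# returns roughly:
#     "thanks [user], [subreddit] rocks!  grinning_face  [url]"
# Note that lower() runs last, so the inserted placeholders are lowercased as well.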

def load_model_and_resources():
    """Load the ONNX model, tokenizer, emotion labels, and thresholds from Hugging Face."""
    repo_id = "logasanjeev/emotions-analyzer-bert"
    
    try:
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading tokenizer: {str(e)}")

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.onnx")
        session = ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model: {str(e)}")

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        with open(thresholds_file, "r") as f:
            thresholds_data = json.load(f)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}")

    return session, tokenizer, emotion_labels, thresholds

# Lazily-initialized globals; populated on the first call to predict_emotions().
SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = None, None, None, None

def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT ONNX model.
    
    Args:
        text (str): The input text to analyze.
    
    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): Formatted string of predicted emotions and their confidence scores.
            - processed_text (str): The preprocessed input text.
    """
    global SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS
    
    if SESSION is None:
        SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = load_model_and_resources()

    processed_text = preprocess_text(text)
    
    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np'
    )
    
    inputs = {
        'input_ids': encodings['input_ids'].astype(np.int64),
        'attention_mask': encodings['attention_mask'].astype(np.int64)
    }
    
    # Run inference; the model emits one logit per emotion label.
    logits = SESSION.run(None, inputs)[0][0]
    probs = 1 / (1 + np.exp(-logits))  # sigmoid for multi-label probabilities

    # Keep every emotion whose probability clears its per-label threshold.
    predictions = []
    for label, prob, thresh in zip(EMOTION_LABELS, probs, THRESHOLDS):
        if prob >= thresh:
            predictions.append((label, round(float(prob), 4)))
    
    predictions.sort(key=lambda x: x[1], reverse=True)
    
    result = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions]) or "No emotions predicted."
    return result, processed_text

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT ONNX model.")
    parser.add_argument("text", type=str, help="The input text to analyze for emotions.")
    args = parser.parse_args()
    
    result, processed = predict_emotions(args.text)
    print(f"Input: {args.text}")
    print(f"Processed: {processed}")
    print("Predicted Emotions:")
    print(result)
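
# Usage sketch, assuming this file is saved as predict_emotions.py (the filename is
# an assumption, not part of the repo) and numpy, emoji, transformers, onnxruntime,
# and huggingface_hub are installed:
#
#     python predict_emotions.py "Thanks u/JaneDoe, r/python rocks!"
#
# Programmatic use from another module works the same way:
#
#     result, processed = predict_emotions("I really appreciate your help!")
#     print(processed)   # preprocessed text fed to the tokenizer
#     print(result)      # one "emotion: score" line per label above its threshold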