File size: 3,785 Bytes
087f7cd
 
5c8ab8e
513a891
 
 
 
 
 
 
 
 
 
 
 
087f7cd
513a891
 
f68271b
513a891
 
 
 
 
 
 
 
5c8ab8e
 
 
513a891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
087f7cd
513a891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import json
from huggingface_hub import hf_hub_download
import re
import emoji
from transformers import BertForSequenceClassification, BertTokenizer

def preprocess_text(text):
    """Preprocess the input text to match training conditions."""
    text = re.sub(r'u/\w+', '[USER]', text)
    text = re.sub(r'r/\w+', '[SUBREDDIT]', text)
    text = re.sub(r'http[s]?://\S+', '[URL]', text)
    text = emoji.demojize(text, delimiters=(" ", " "))
    text = text.lower()
    return text

def load_model_and_resources():
    """Load the model, tokenizer, emotion labels, and thresholds from Hugging Face."""
    repo_id = "logasanjeev/emotions-analyzer-bert"
    
    try:
        model = BertForSequenceClassification.from_pretrained(repo_id)
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        raise RuntimeError(f"Error loading model/tokenizer: {str(e)}")

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        with open(thresholds_file, "r") as f:
            thresholds_data = json.load(f)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    return model, tokenizer, emotion_labels, thresholds, device

MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = None, None, None, None, None

def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT model.
    
    Args:
        text (str): The input text to analyze.
    
    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): Formatted string of predicted emotions and their confidence scores.
            - processed_text (str): The preprocessed input text.
    """
    global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE
    
    if MODEL is None:
        MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = load_model_and_resources()

    processed_text = preprocess_text(text)
    
    encodings = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    
    input_ids = encodings['input_ids'].to(DEVICE)
    attention_mask = encodings['attention_mask'].to(DEVICE)
    
    with torch.no_grad():
        outputs = MODEL(input_ids, attention_mask=attention_mask)
        logits = torch.sigmoid(outputs.logits).cpu().numpy()[0]
    
    predictions = []
    for i, (logit, thresh) in enumerate(zip(logits, THRESHOLDS)):
        if logit >= thresh:
            predictions.append((EMOTION_LABELS[i], round(logit, 4)))
    
    predictions.sort(key=lambda x: x[1], reverse=True)
    
    result = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in predictions]) or "No emotions predicted."
    return result, processed_text

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT model.")
    parser.add_argument("text", type=str, help="The input text to analyze for emotions.")
    args = parser.parse_args()
    
    result, processed = predict_emotions(args.text)
    print(f"Input: {args.text}")
    print(f"Processed: {processed}")
    print("Predicted Emotions:")
    print(result)