logasanjeev committed on
Commit
d178eaa
·
verified ·
1 Parent(s): 5c8ab8e

Create onnx_inference.py

Browse files
Files changed (1) hide show
  1. onnx_inference.py +103 -0
onnx_inference.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import json
3
+ from huggingface_hub import hf_hub_download
4
+ import re
5
+ import emoji
6
+ from transformers import BertTokenizer
7
+ import onnxruntime as ort
8
+
9
def preprocess_text(text):
    """Normalize raw input text to match the conditions used at training time.

    Reddit user mentions (``u/name``), subreddit mentions (``r/name``) and
    http(s) URLs are replaced by placeholder tokens, emojis are passed through
    ``emoji.demojize`` with single spaces as delimiters, and the whole string
    (placeholders included) is lowercased last.
    """
    # Applied in this exact order: mentions are rewritten before URLs.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    for pattern, placeholder in substitutions:
        text = re.sub(pattern, placeholder, text)

    return emoji.demojize(text, delimiters=(" ", " ")).lower()
17
+
18
def load_model_and_resources():
    """Load the ONNX model, tokenizer, emotion labels, and thresholds from Hugging Face.

    Everything is fetched from the ``logasanjeev/goemotions-bert`` repository:
    the tokenizer via ``BertTokenizer.from_pretrained``, ``model.onnx`` and
    ``optimized_thresholds.json`` via ``hf_hub_download``.

    Returns:
        tuple: ``(session, tokenizer, emotion_labels, thresholds)`` where
        ``emotion_labels`` and ``thresholds`` are the values stored under the
        keys of the same name in ``optimized_thresholds.json``.

    Raises:
        RuntimeError: If the tokenizer, the ONNX model, or the thresholds file
            cannot be loaded (including a thresholds file with an unexpected
            layout). The underlying exception is attached as ``__cause__``.
    """
    repo_id = "logasanjeev/goemotions-bert"

    try:
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        # Chain explicitly so the traceback shows the root cause as the direct cause.
        raise RuntimeError(f"Error loading tokenizer: {e}") from e

    try:
        model_path = hf_hub_download(repo_id=repo_id, filename="model.onnx")
        session = ort.InferenceSession(model_path)
    except Exception as e:
        raise RuntimeError(f"Error loading ONNX model: {e}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        # JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)
        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")
        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]
    except Exception as e:
        # The ValueError raised above is deliberately wrapped here as well,
        # so callers only ever have to handle RuntimeError.
        raise RuntimeError(f"Error loading thresholds: {e}") from e

    return session, tokenizer, emotion_labels, thresholds
45
+
46
# Module-level cache of the inference resources; populated on first use by predict_emotions().
SESSION = TOKENIZER = EMOTION_LABELS = THRESHOLDS = None
47
+
48
def predict_emotions(text):
    """Predict emotions for a piece of text with the GoEmotions BERT ONNX model.

    The session, tokenizer, labels and thresholds are loaded on the first call
    and kept in the module-level globals for subsequent calls.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One ``"emotion: score"`` line per emotion whose
              sigmoid score reaches its threshold, highest score first, or
              ``"No emotions predicted."`` when no emotion qualifies.
            - processed_text (str): The input after ``preprocess_text``.
    """
    global SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS

    # One-time lazy initialisation of the shared resources.
    if SESSION is None:
        SESSION, TOKENIZER, EMOTION_LABELS, THRESHOLDS = load_model_and_resources()

    processed_text = preprocess_text(text)

    tokenized = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='np',
    )

    # The session is fed int64 arrays for both inputs.
    feed = {
        key: tokenized[key].astype(np.int64)
        for key in ('input_ids', 'attention_mask')
    }

    # First output of the session, first row of that output.
    raw_scores = SESSION.run(None, feed)[0][0]
    scores = 1 / (1 + np.exp(-raw_scores))  # Sigmoid

    # Keep only the emotions whose score reaches the matching threshold.
    selected = [
        (EMOTION_LABELS[idx], round(score, 4))
        for idx, (score, cutoff) in enumerate(zip(scores, THRESHOLDS))
        if score >= cutoff
    ]
    ranked = sorted(selected, key=lambda pair: pair[1], reverse=True)

    report = "\n".join([f"{emotion}: {confidence:.4f}" for emotion, confidence in ranked])
    if not report:
        report = "No emotions predicted."
    return report, processed_text
91
+
92
if __name__ == "__main__":
    # Command-line entry point: classify the single positional TEXT argument
    # and print the raw input, the preprocessed text and the predictions.
    import argparse

    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT ONNX model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    options = cli.parse_args()

    report, cleaned = predict_emotions(options.text)
    output_lines = (
        f"Input: {options.text}",
        f"Processed: {cleaned}",
        "Predicted Emotions:",
        report,
    )
    for line in output_lines:
        print(line)