lalithadevi committed on
Commit
f4d9674
·
verified ·
1 Parent(s): aed8cf4

Upload calibrated_classifier.py

Browse files
Files changed (1) hide show
  1. calibrated_classifier.py +84 -0
calibrated_classifier.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.dummy import DummyClassifier
2
+ from tqdm import tqdm
3
+ import multiprocessing
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from transformers import DistilBertTokenizerFast
7
+
8
+
9
+
10
class PredictProba(DummyClassifier):
    """Estimator shim that produces predictions from a TFLite model file.

    The class exposes ``classes_``, ``fit`` and ``predict_proba`` so that it
    can be handed to a calibration classifier as an already-trained
    estimator. Texts are tokenized with the ``distilbert-base-uncased``
    tokenizer, split into batches, and each batch is run through its own
    TFLite interpreter in a worker process.
    """

    def __init__(self, tflite_model_path: str, classes_: list, n_tokens: int):
        """Store the model location and the fixed inference settings.

        Args:
            tflite_model_path: Path of the ``.tflite`` model file to load.
            classes_: Class labels; its length is used as the width of the
                model output tensor.
            n_tokens: Sequence length every text is padded/truncated to.
        """
        # required attribute for an estimator to be used in calibration classifier
        self.classes_ = classes_
        self.n_tokens = n_tokens
        self.tflite_model_path = tflite_model_path

    def fit(self, x, y):
        """Do nothing except announce the call; returns ``self``.

        A fit method is required for an estimator to be used in a
        calibration classifier, but no training happens here.
        """
        print('called fit')
        return self

    @staticmethod
    def get_token_batches(attention_mask, input_ids, batch_size: int = 8):
        """Split the tokenizer outputs into consecutive batches.

        Args:
            attention_mask: Sliceable sequence with one row per text.
            input_ids: Sliceable sequence with one row per text.
            batch_size: Maximum number of rows per batch; must be >= 1.

        Returns:
            A ``(attention_mask_batches, input_ids_batches)`` pair of lists
            with matching batch boundaries. At least one batch is always
            returned, so empty input yields a single empty batch.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        n_texts = len(attention_mask)
        # Ceiling division; never fewer than one batch (also for empty input).
        n_batches = max((n_texts + batch_size - 1) // batch_size, 1)

        attention_mask_batches = []
        input_ids_batches = []
        for i in range(n_batches):
            start = i * batch_size
            # The last batch is open-ended so it takes every remaining row.
            stop = None if i == n_batches - 1 else start + batch_size
            attention_mask_batches.append(attention_mask[start:stop])
            input_ids_batches.append(input_ids[start:stop])

        return attention_mask_batches, input_ids_batches

    def get_batch_inference(self, batch_size, attention_mask, input_ids):
        """Run one batch through a freshly created TFLite interpreter.

        ``inference`` dispatches this method to worker processes, so the
        interpreter is built from ``self.tflite_model_path`` on every call
        rather than shared.

        Args:
            batch_size: Number of rows in this batch; the input and output
                tensors are resized to it.
            attention_mask: Attention mask rows for the batch (fed to the
                model's first input).
            input_ids: Token id rows for the batch (fed to the model's
                second input).

        Returns:
            The content of the model's first output tensor after invocation.
        """
        interpreter = tf.lite.Interpreter(model_path=self.tflite_model_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()[0]

        # Resize the tensors to this batch, then re-allocate before use.
        interpreter.resize_tensor_input(input_details[0]['index'], [batch_size, self.n_tokens])
        interpreter.resize_tensor_input(input_details[1]['index'], [batch_size, self.n_tokens])
        interpreter.resize_tensor_input(output_details['index'], [batch_size, len(self.classes_)])
        interpreter.allocate_tensors()

        interpreter.set_tensor(input_details[0]["index"], attention_mask)
        interpreter.set_tensor(input_details[1]["index"], input_ids)
        interpreter.invoke()
        tflite_pred = interpreter.get_tensor(output_details["index"])
        return tflite_pred

    def inference(self, texts):
        """Tokenize ``texts`` and return the model predictions for all of them.

        Batches are evaluated in parallel worker processes and the per-batch
        outputs are concatenated along axis 0 in input order.

        Args:
            texts: The texts to score, in the form accepted by the tokenizer.

        Returns:
            The stacked predictions of all batches.

        Raises:
            multiprocessing.TimeoutError: If a batch result is not available
                within 360 seconds.
        """
        model_checkpoint = "distilbert-base-uncased"
        tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)
        tokens = tokenizer(texts, max_length=self.n_tokens, padding="max_length",
                           truncation=True, return_tensors="tf")
        attention_mask, input_ids = tokens['attention_mask'], tokens['input_ids']
        attention_mask_batches, input_ids_batches = self.get_token_batches(attention_mask, input_ids)

        # No point in starting more workers than there are batches to run.
        n_workers = max(1, min(multiprocessing.cpu_count(), len(attention_mask_batches)))

        # The context manager tears the pool down on exit (also when a
        # result times out), so worker processes are not leaked per call.
        with multiprocessing.Pool(processes=n_workers) as pool:
            pending = [
                pool.apply_async(self.get_batch_inference,
                                 args=(len(mask_batch), mask_batch, ids_batch))
                for mask_batch, ids_batch in zip(attention_mask_batches, input_ids_batches)
            ]
            # Results must be collected while the pool is still alive.
            batch_predictions = [result.get(timeout=360) for result in tqdm(pending)]

        if not batch_predictions:
            return np.array([])
        # Single concatenation instead of re-copying the array per batch.
        return np.concatenate(batch_predictions, axis=0)

    def predict_proba(self, X, y=None):
        """Return the model output for ``X``; ``y`` is accepted but unused."""
        predict_prob = self.inference(X)
        return predict_prob