Update model.py
Browse files
model.py
CHANGED
|
@@ -136,20 +136,24 @@ class LidirlCNN(PreTrainedModel):
|
|
| 136 |
probs = torch.softmax(logits, dim=-1)
|
| 137 |
return probs
|
| 138 |
|
| 139 |
-
def predict(self, inputs, lengths, threshold=0.5):
|
| 140 |
probs = self.__call__(inputs, lengths)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
if self.multilabel:
|
| 142 |
batch_idx, label_idx = torch.where(probs > threshold)
|
| 143 |
-
output = []
|
| 144 |
for batch, label in zip(batch_idx, label_idx):
|
| 145 |
-
if len(output) <= batch.item():
|
| 146 |
-
output.append([])
|
| 147 |
label_string = self.labels
|
| 148 |
-
output[
|
| 149 |
(self.labels[label.item()], probs[batch, label])
|
| 150 |
)
|
| 151 |
return output
|
| 152 |
-
|
| 153 |
|
| 154 |
|
| 155 |
|
|
|
|
| 136 |
probs = torch.softmax(logits, dim=-1)
|
| 137 |
return probs
|
| 138 |
|
| 139 |
+
def predict(self, inputs, lengths, threshold=0.5, top_k=None):
|
| 140 |
probs = self.__call__(inputs, lengths)
|
| 141 |
+
if top_k is not None and top_k > 0:
|
| 142 |
+
top_k_preds = torch.topk(probs, top_k, dim=1)
|
| 143 |
+
pred_labels = []
|
| 144 |
+
for pred, prob in zip(top_k_preds.indices, top_k_preds.values):
|
| 145 |
+
pred_labels.append([(self.labels[p.item()], pr.item()) for (p, pr) in zip(pred, prob)])
|
| 146 |
+
return pred_labels
|
| 147 |
if self.multilabel:
|
| 148 |
batch_idx, label_idx = torch.where(probs > threshold)
|
| 149 |
+
output = [[] for _ in range(len(inputs))]
|
| 150 |
for batch, label in zip(batch_idx, label_idx):
|
|
|
|
|
|
|
| 151 |
label_string = self.labels
|
| 152 |
+
output[batch.item()].append(
|
| 153 |
(self.labels[label.item()], probs[batch, label])
|
| 154 |
)
|
| 155 |
return output
|
| 156 |
+
|
| 157 |
|
| 158 |
|
| 159 |
|