rewicks commited on
Commit
92b8a8c
·
verified ·
1 Parent(s): 8e35093

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -6
model.py CHANGED
@@ -136,20 +136,24 @@ class LidirlCNN(PreTrainedModel):
136
  probs = torch.softmax(logits, dim=-1)
137
  return probs
138
 
139
- def predict(self, inputs, lengths, threshold=0.5):
140
  probs = self.__call__(inputs, lengths)
 
 
 
 
 
 
141
  if self.multilabel:
142
  batch_idx, label_idx = torch.where(probs > threshold)
143
- output = []
144
  for batch, label in zip(batch_idx, label_idx):
145
- if len(output) <= batch.item():
146
- output.append([])
147
  label_string = self.labels
148
- output[-1].append(
149
  (self.labels[label.item()], probs[batch, label])
150
  )
151
  return output
152
-
153
 
154
 
155
 
 
136
  probs = torch.softmax(logits, dim=-1)
137
  return probs
138
 
139
+ def predict(self, inputs, lengths, threshold=0.5, top_k=None):
140
  probs = self.__call__(inputs, lengths)
141
+ if top_k is not None and top_k > 0:
142
+ top_k_preds = torch.topk(probs, top_k, dim=1)
143
+ pred_labels = []
144
+ for pred, prob in zip(top_k_preds.indices, top_k_preds.values):
145
+ pred_labels.append([(self.labels[p.item()], pr.item()) for (p, pr) in zip(pred, prob)])
146
+ return pred_labels
147
  if self.multilabel:
148
  batch_idx, label_idx = torch.where(probs > threshold)
149
+ output = [[] for _ in range(len(inputs))]
150
  for batch, label in zip(batch_idx, label_idx):
 
 
151
  label_string = self.labels
152
+ output[batch.item()].append(
153
  (self.labels[label.item()], probs[batch, label])
154
  )
155
  return output
156
+
157
 
158
 
159