|
"Implements various metrics to measure training accuracy" |
|
from .torch_core import * |
|
from .callback import * |
|
from .layers import * |
|
from .basic_train import LearnerCallback |
|
|
|
__all__ = ['error_rate', 'accuracy', 'accuracy_thresh', 'dice', 'exp_rmspe', 'fbeta','FBeta', 'mse', 'mean_squared_error', |
|
'mae', 'mean_absolute_error', 'rmse', 'root_mean_squared_error', 'msle', 'mean_squared_logarithmic_error', |
|
'explained_variance', 'r2_score', 'top_k_accuracy', 'KappaScore', 'ConfusionMatrix', 'MatthewsCorreff', |
|
'Precision', 'Recall', 'R2Score', 'ExplainedVariance', 'ExpRMSPE', 'RMSE', 'Perplexity', 'AUROC', 'auc_roc_score', |
|
'roc_curve', 'MultiLabelFbeta', 'foreground_acc'] |
|
|
|
def fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor: |
|
"Computes the f_beta between `preds` and `targets`" |
|
beta2 = beta ** 2 |
|
if sigmoid: y_pred = y_pred.sigmoid() |
|
y_pred = (y_pred>thresh).float() |
|
y_true = y_true.float() |
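    # Per-sample precision and recall over the thresholded labels, combined into F-beta and averaged over the batch.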
|
TP = (y_pred*y_true).sum(dim=1) |
|
prec = TP/(y_pred.sum(dim=1)+eps) |
|
rec = TP/(y_true.sum(dim=1)+eps) |
|
res = (prec*rec)/(prec*beta2+rec+eps)*(1+beta2) |
|
return res.mean() |
|
|
|
def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor: |
|
"Computes accuracy with `targs` when `input` is bs * n_classes." |
|
n = targs.shape[0] |
|
input = input.argmax(dim=-1).view(n,-1) |
|
targs = targs.view(n,-1) |
|
return (input==targs).float().mean() |
|
|
|
def accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor: |
|
"Computes accuracy when `y_pred` and `y_true` are the same size." |
|
if sigmoid: y_pred = y_pred.sigmoid() |
|
return ((y_pred>thresh)==y_true.byte()).float().mean() |
|
|
|
def top_k_accuracy(input:Tensor, targs:Tensor, k:int=5)->Rank0Tensor: |
|
"Computes the Top-k accuracy (target is in the top k predictions)." |
|
input = input.topk(k=k, dim=-1)[1] |
|
targs = targs.unsqueeze(dim=-1).expand_as(input) |
|
return (input == targs).max(dim=-1)[0].float().mean() |
|
|
|
def foreground_acc(input, target, void_code): |
|
"Computes non-background accuracy, e.g. camvid for multiclass segmentation" |
|
target = target.squeeze(1) |
|
mask = target != void_code |
|
return (input.argmax(dim=1)[mask]==target[mask]).float().mean() |
|
|
|
def error_rate(input:Tensor, targs:Tensor)->Rank0Tensor: |
|
"1 - `accuracy`" |
|
return 1 - accuracy(input, targs) |
|
|
|
def dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor: |
|
"Dice coefficient metric for binary target. If iou=True, returns iou metric, classic for segmentation problems." |
|
n = targs.shape[0] |
|
input = input.argmax(dim=1).view(n,-1) |
|
targs = targs.view(n,-1) |
|
intersect = (input * targs).sum().float() |
|
union = (input+targs).sum().float() |
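    # If both prediction and target are empty (union == 0), the score defaults to 1.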
|
if not iou: return (2. * intersect / union if union > 0 else union.new([1.]).squeeze()) |
|
else: return (intersect / (union-intersect+eps) if union > 0 else union.new([1.]).squeeze()) |
|
|
|
def psnr(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Peak signal-to-noise ratio between `input` and `targs`, assuming values scaled to [0,1]."
    return 10 * (1. / mean_squared_error(input, targs)).log10()
|
|
|
def exp_rmspe(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Exp RMSE between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
pred, targ = torch.exp(pred), torch.exp(targ) |
|
pct_var = (targ - pred)/targ |
|
return torch.sqrt((pct_var**2).mean()) |
|
|
|
def mean_absolute_error(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Mean absolute error between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
return torch.abs(targ - pred).mean() |
|
|
|
def mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Mean squared error between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
return F.mse_loss(pred, targ) |
|
|
|
def root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Root mean squared error between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
return torch.sqrt(F.mse_loss(pred, targ)) |
|
|
|
def mean_squared_logarithmic_error(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Mean squared logarithmic error between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
return F.mse_loss(torch.log(1 + pred), torch.log(1 + targ)) |
|
|
|
def explained_variance(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"Explained variance between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
var_pct = torch.var(targ - pred) / torch.var(targ) |
|
return 1 - var_pct |
|
|
|
def r2_score(pred:Tensor, targ:Tensor)->Rank0Tensor: |
|
"R2 score (coefficient of determination) between `pred` and `targ`." |
|
pred,targ = flatten_check(pred,targ) |
|
u = torch.sum((targ - pred) ** 2) |
|
d = torch.sum((targ - targ.mean()) ** 2) |
|
return 1 - u / d |
|
|
|
class RegMetrics(Callback): |
|
"Stores predictions and targets to perform calculations on epoch end." |
|
def on_epoch_begin(self, **kwargs): |
|
self.targs, self.preds = Tensor([]), Tensor([]) |
|
|
|
def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs): |
|
assert last_output.numel() == last_target.numel(), "Expected same numbers of elements in pred & targ" |
|
self.preds = torch.cat((self.preds, last_output.cpu())) |
|
self.targs = torch.cat((self.targs, last_target.cpu())) |
|
|
|
class R2Score(RegMetrics): |
|
"Computes the R2 score (coefficient of determination)." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, r2_score(self.preds, self.targs)) |
|
|
|
class ExplainedVariance(RegMetrics): |
|
"Computes the explained variance." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, explained_variance(self.preds, self.targs)) |
|
|
|
class RMSE(RegMetrics): |
|
"Computes the root mean squared error." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, root_mean_squared_error(self.preds, self.targs)) |
|
|
|
class ExpRMSPE(RegMetrics): |
|
"Computes the exponential of the root mean square error." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, exp_rmspe(self.preds, self.targs)) |
|
|
|
|
|
mse = mean_squared_error |
|
mae = mean_absolute_error |
|
msle = mean_squared_logarithmic_error |
|
rmse = root_mean_squared_error |
|
|
|
class ConfusionMatrix(Callback): |
|
"Computes the confusion matrix." |
|
|
|
def on_train_begin(self, **kwargs): |
|
self.n_classes = 0 |
|
|
|
def on_epoch_begin(self, **kwargs): |
|
self.cm = None |
|
|
|
def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs): |
|
preds = last_output.argmax(-1).view(-1).cpu() |
|
targs = last_target.cpu() |
|
if self.n_classes == 0: |
|
self.n_classes = last_output.shape[-1] |
|
self.x = torch.arange(0, self.n_classes) |
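        # Broadcasted comparison builds an (n_classes, n_classes) matrix where cm[i, j]
        # counts samples with target class i predicted as class j.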
|
cm = ((preds==self.x[:, None]) & (targs==self.x[:, None, None])).sum(dim=2, dtype=torch.float32) |
|
if self.cm is None: self.cm = cm |
|
else: self.cm += cm |
|
|
|
def on_epoch_end(self, **kwargs): |
|
self.metric = self.cm |
|
|
|
@dataclass |
|
class CMScores(ConfusionMatrix): |
|
"Base class for metrics which rely on the calculation of the precision and/or recall score." |
|
average:Optional[str]="binary" |
|
pos_label:int=1 |
|
eps:float=1e-9 |
|
|
|
def _recall(self): |
|
rec = torch.diag(self.cm) / self.cm.sum(dim=1) |
|
if self.average is None: return rec |
|
else: |
|
if self.average == "micro": weights = self._weights(avg="weighted") |
|
else: weights = self._weights(avg=self.average) |
|
return (rec * weights).sum() |
|
|
|
def _precision(self): |
|
prec = torch.diag(self.cm) / self.cm.sum(dim=0) |
|
if self.average is None: return prec |
|
else: |
|
weights = self._weights(avg=self.average) |
|
return (prec * weights).sum() |
|
|
|
def _weights(self, avg:str): |
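        # Per-class weights for averaging: "binary" selects `pos_label` only, "macro" weighs
        # classes equally, "weighted" uses true-class support, "micro" uses predicted counts.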
|
if self.n_classes != 2 and avg == "binary": |
|
avg = self.average = "macro" |
|
warn("average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.") |
|
if avg == "binary": |
|
if self.pos_label not in (0, 1): |
|
self.pos_label = 1 |
|
warn("Invalid value for pos_label. It has now been set to 1.") |
|
if self.pos_label == 1: return Tensor([0,1]) |
|
else: return Tensor([1,0]) |
|
elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum() |
|
elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes |
|
elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum() |
|
|
|
|
|
class Recall(CMScores): |
|
"Computes the Recall." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, self._recall()) |
|
|
|
class Precision(CMScores): |
|
"Computes the Precision." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, self._precision()) |
|
|
|
@dataclass |
|
class FBeta(CMScores): |
|
"Computes the F`beta` score." |
|
beta:float=2 |
|
|
|
def on_train_begin(self, **kwargs): |
|
self.n_classes = 0 |
|
self.beta2 = self.beta ** 2 |
|
self.avg = self.average |
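        # Keep the requested averaging for reporting, but compute per-class precision/recall
        # (average=None) unless "micro" was requested.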
|
if self.average != "micro": self.average = None |
|
|
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
prec = self._precision() |
|
rec = self._recall() |
|
metric = (1 + self.beta2) * prec * rec / (prec * self.beta2 + rec + self.eps) |
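        # Classes with no predictions or no targets yield NaN; count those as 0.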
|
metric[metric != metric] = 0 |
|
if self.avg: metric = (self._weights(avg=self.avg) * metric).sum() |
|
return add_metrics(last_metrics, metric) |
|
|
|
def on_train_end(self, **kwargs): self.average = self.avg |
|
|
|
@dataclass |
|
class KappaScore(ConfusionMatrix): |
|
"Computes the rate of agreement (Cohens Kappa)." |
|
weights:Optional[str]=None |
|
|
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
sum0 = self.cm.sum(dim=0) |
|
sum1 = self.cm.sum(dim=1) |
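        # Expected counts under chance agreement (outer product of the confusion-matrix marginals).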
|
expected = torch.einsum('i,j->ij', (sum0, sum1)) / sum0.sum() |
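        # Disagreement weights: equal off-diagonal penalty for plain kappa, |i-j| for linear and (i-j)**2 for quadratic weighting.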
|
if self.weights is None: |
|
w = torch.ones((self.n_classes, self.n_classes)) |
|
w[self.x, self.x] = 0 |
|
elif self.weights == "linear" or self.weights == "quadratic": |
|
w = torch.zeros((self.n_classes, self.n_classes)) |
|
w += torch.arange(self.n_classes, dtype=torch.float) |
|
w = torch.abs(w - torch.t(w)) if self.weights == "linear" else (w - torch.t(w)) ** 2 |
|
else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".') |
|
k = torch.sum(w * self.cm) / torch.sum(w * expected) |
|
return add_metrics(last_metrics, 1-k) |
|
|
|
@dataclass |
|
class MatthewsCorreff(ConfusionMatrix): |
|
"Computes the Matthews correlation coefficient." |
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
t_sum = self.cm.sum(dim=1) |
|
p_sum = self.cm.sum(dim=0) |
|
n_correct = torch.trace(self.cm) |
|
n_samples = p_sum.sum() |
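        # MCC from the confusion-matrix marginals: cov(true, pred) / sqrt(cov(true, true) * cov(pred, pred)).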
|
cov_ytyp = n_correct * n_samples - torch.dot(t_sum, p_sum) |
|
cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum) |
|
cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum) |
|
return add_metrics(last_metrics, cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)) |
|
|
|
class Perplexity(Callback): |
|
"Perplexity metric for language models." |
|
def on_epoch_begin(self, **kwargs): self.loss,self.len = 0.,0 |
|
|
|
def on_batch_end(self, last_output, last_target, **kwargs): |
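        # Accumulate the batch's mean cross-entropy weighted by target sequence length.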
|
self.loss += last_target.size(1) * CrossEntropyFlat()(last_output, last_target) |
|
self.len += last_target.size(1) |
|
|
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, torch.exp(self.loss / self.len)) |
|
|
|
def auc_roc_score(input:Tensor, targ:Tensor): |
|
"Computes the area under the receiver operator characteristic (ROC) curve using the trapezoid method. Restricted binary classification tasks." |
|
fpr, tpr = roc_curve(input, targ) |
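    # Trapezoidal rule: FPR increments times the mean of adjacent TPR values.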
|
d = fpr[1:] - fpr[:-1] |
|
sl1, sl2 = [slice(None)], [slice(None)] |
|
sl1[-1], sl2[-1] = slice(1, None), slice(None, -1) |
|
return (d * (tpr[tuple(sl1)] + tpr[tuple(sl2)]) / 2.).sum(-1) |
|
|
|
def roc_curve(input:Tensor, targ:Tensor): |
|
"Computes the receiver operator characteristic (ROC) curve by determining the true positive ratio (TPR) and false positive ratio (FPR) for various classification thresholds. Restricted binary classification tasks." |
|
targ = (targ == 1) |
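    # Sort scores in descending order; thresholds are taken at distinct score values.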
|
desc_score_indices = torch.flip(input.argsort(-1), [-1]) |
|
input = input[desc_score_indices] |
|
targ = targ[desc_score_indices] |
|
d = input[1:] - input[:-1] |
|
distinct_value_indices = torch.nonzero(d).transpose(0,1)[0] |
|
threshold_idxs = torch.cat((distinct_value_indices, LongTensor([len(targ) - 1]).to(targ.device))) |
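    # Cumulative true/false positive counts at each threshold; make sure the curve starts at (0, 0).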
|
tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs] |
|
fps = (1 + threshold_idxs - tps) |
|
if tps[0] != 0 or fps[0] != 0: |
|
fps = torch.cat((LongTensor([0]), fps)) |
|
tps = torch.cat((LongTensor([0]), tps)) |
|
fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1] |
|
return fpr, tpr |
|
|
|
@dataclass |
|
class AUROC(Callback): |
|
"Computes the area under the curve (AUC) score based on the receiver operator characteristic (ROC) curve. Restricted to binary classification tasks." |
|
def on_epoch_begin(self, **kwargs): |
|
self.targs, self.preds = LongTensor([]), Tensor([]) |
|
|
|
def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs): |
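        # Keep the softmax probability of the last class, treated as the positive class.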
|
last_output = F.softmax(last_output, dim=1)[:,-1] |
|
self.preds = torch.cat((self.preds, last_output.cpu())) |
|
self.targs = torch.cat((self.targs, last_target.cpu().long())) |
|
|
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
return add_metrics(last_metrics, auc_roc_score(self.preds, self.targs)) |
|
|
|
class MultiLabelFbeta(LearnerCallback): |
|
"Computes the fbeta score for multilabel classification" |
|
|
|
_order = -20 |
|
def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average="micro"): |
|
super().__init__(learn) |
|
self.eps, self.thresh, self.sigmoid, self.average, self.beta2 = \ |
|
eps, thresh, sigmoid, average, beta**2 |
|
|
|
def on_train_begin(self, **kwargs): |
|
self.c = self.learn.data.c |
|
if self.average != "none": self.learn.recorder.add_metric_names([f'{self.average}_fbeta']) |
|
else: self.learn.recorder.add_metric_names([f"fbeta_{c}" for c in self.learn.data.classes]) |
|
|
|
def on_epoch_begin(self, **kwargs): |
|
dvc = self.learn.data.device |
|
self.tp = torch.zeros(self.c).to(dvc) |
|
self.total_pred = torch.zeros(self.c).to(dvc) |
|
self.total_targ = torch.zeros(self.c).to(dvc) |
|
|
|
def on_batch_end(self, last_output, last_target, **kwargs): |
|
pred, targ = (last_output.sigmoid() if self.sigmoid else last_output) > self.thresh, last_target.byte() |
|
m = pred*targ |
|
self.tp += m.sum(0).float() |
|
self.total_pred += pred.sum(0).float() |
|
self.total_targ += targ.sum(0).float() |
|
|
|
def fbeta_score(self, precision, recall): |
|
return (1 + self.beta2)*(precision*recall)/((self.beta2*precision + recall) + self.eps) |
|
|
|
def on_epoch_end(self, last_metrics, **kwargs): |
|
self.total_pred += self.eps |
|
self.total_targ += self.eps |
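        # "micro" pools TP/prediction/target counts across classes before computing F-beta; "macro"
        # averages per-class scores; "weighted" weighs them by class support; "none" returns per-class scores.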
|
if self.average == "micro": |
|
precision, recall = self.tp.sum() / self.total_pred.sum(), self.tp.sum() / self.total_targ.sum() |
|
res = self.fbeta_score(precision, recall) |
|
elif self.average == "macro": |
|
res = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)).mean() |
|
elif self.average == "weighted": |
|
scores = self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ)) |
|
res = (scores*self.total_targ).sum() / self.total_targ.sum() |
|
elif self.average == "none": |
|
res = listify(self.fbeta_score((self.tp / self.total_pred), (self.tp / self.total_targ))) |
|
else: |
|
            raise ValueError("Choose one of the average types: [micro, macro, weighted, none]")
|
|
|
return add_metrics(last_metrics, res) |
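
# Usage sketch (illustrative only; assumes a fastai v1 `DataBunch` named `data` and `cnn_learner`
# imported from fastai.vision; neither is defined in this module):
#   learn = cnn_learner(data, models.resnet34,
#                       metrics=[accuracy, error_rate, FBeta(average='macro')])
#   learn.fit_one_cycle(4)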
|
|