# Upload metadata (page residue): sayed99 · "project upload" · cc9dfd7
"Implements various metrics to measure training accuracy"
from .torch_core import *
from .callback import *
from .layers import *
from .basic_train import LearnerCallback
# Names exported by a star-import of this module.
__all__ = [
    'error_rate', 'accuracy', 'accuracy_thresh', 'dice', 'exp_rmspe',
    'fbeta', 'FBeta',
    'mse', 'mean_squared_error', 'mae', 'mean_absolute_error',
    'rmse', 'root_mean_squared_error', 'msle', 'mean_squared_logarithmic_error',
    'explained_variance', 'r2_score', 'top_k_accuracy',
    'KappaScore', 'ConfusionMatrix', 'MatthewsCorreff', 'Precision', 'Recall',
    'R2Score', 'ExplainedVariance', 'ExpRMSPE', 'RMSE', 'Perplexity',
    'AUROC', 'auc_roc_score', 'roc_curve', 'MultiLabelFbeta', 'foreground_acc',
]
def fbeta(y_pred:Tensor, y_true:Tensor, thresh:float=0.2, beta:float=2, eps:float=1e-9, sigmoid:bool=True)->Rank0Tensor:
    """F-beta score of thresholded predictions, averaged over rows.

    `y_pred` holds scores (passed through a sigmoid first when `sigmoid=True`); an entry counts as
    predicted positive when its score exceeds `thresh`. `y_true` holds 0/1 targets of the same shape.
    Precision and recall are counted along dim 1 (presumably the label axis), the per-row F-beta is
    then averaged. `eps` guards the divisions against empty rows.
    """
    probs = y_pred.sigmoid() if sigmoid else y_pred
    hits = (probs > thresh).float()
    truth = y_true.float()
    true_pos = (hits * truth).sum(dim=1)
    precision = true_pos / (hits.sum(dim=1) + eps)
    recall = true_pos / (truth.sum(dim=1) + eps)
    b2 = beta ** 2
    per_row = (precision * recall) / (precision * b2 + recall + eps) * (1 + b2)
    return per_row.mean()
def accuracy(input:Tensor, targs:Tensor)->Rank0Tensor:
    """Fraction of samples whose argmax over the last dim of `input` equals `targs`.

    `input` is expected to be `bs * n_classes`; predictions and targets are both viewed as
    `(bs, -1)` before the element-wise comparison.
    """
    bs = targs.shape[0]
    predicted = input.argmax(dim=-1).view(bs, -1)
    return (predicted == targs.view(bs, -1)).float().mean()
def accuracy_thresh(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)->Rank0Tensor:
    """Element-wise accuracy of `y_pred > thresh` against `y_true` (tensors of the same size).

    Scores go through a sigmoid first when `sigmoid=True`; targets are cast with `.byte()`.
    """
    scores = y_pred.sigmoid() if sigmoid else y_pred
    correct = (scores > thresh) == y_true.byte()
    return correct.float().mean()
def top_k_accuracy(input:Tensor, targs:Tensor, k:int=5)->Rank0Tensor:
    "Fraction of samples whose target index is among the `k` highest scores on the last dim of `input`."
    top_idx = input.topk(k=k, dim=-1)[1]
    # Broadcast each target against its row of top-k indices.
    in_top = (top_idx == targs.unsqueeze(dim=-1)).any(dim=-1)
    return in_top.float().mean()
def foreground_acc(input, target, void_code):
    """Accuracy over the pixels whose target is not `void_code` (e.g. CamVid multiclass segmentation).

    Class scores are read along dim 1 of `input`; `target` has a singleton dim 1 that is squeezed
    away. Presumably `(bs, n_classes, H, W)` scores vs `(bs, 1, H, W)` masks — confirm against callers.
    """
    labels = target.squeeze(1)
    keep = labels != void_code
    predicted = input.argmax(dim=1)
    return (predicted[keep] == labels[keep]).float().mean()
def error_rate(input:Tensor, targs:Tensor)->Rank0Tensor:
    "Fraction of misclassified samples: the complement of `accuracy(input, targs)`."
    acc = accuracy(input, targs)
    return 1 - acc
def dice(input:Tensor, targs:Tensor, iou:bool=False, eps:float=1e-8)->Rank0Tensor:
    """Dice coefficient for a binary target, or IoU (Jaccard) when `iou=True`.

    Predictions are the argmax over dim 1 of `input`; predictions and targets are viewed as
    `(bs, -1)` and summed over the whole batch. When both masks are empty the score is 1.
    `eps` is only used in the IoU denominator.
    """
    bs = targs.shape[0]
    pred_mask = input.argmax(dim=1).view(bs, -1)
    true_mask = targs.view(bs, -1)
    overlap = (pred_mask * true_mask).sum().float()
    total = (pred_mask + true_mask).sum().float()
    if not (total > 0):
        # Nothing predicted and nothing to find: treat as perfect agreement.
        return total.new([1.]).squeeze()
    if iou:
        return overlap / (total - overlap + eps)
    return 2. * overlap / total
def psnr(input:Tensor, targs:Tensor, max_val:float=1.)->Rank0Tensor:
    """Peak signal-to-noise ratio, in dB, between `input` and `targs`.

    Computed as `10 * log10(max_val**2 / MSE)`, where `max_val` is the largest possible signal
    value. The default of 1 reproduces the previous behaviour (presumably signals scaled to
    [0, 1]); pass e.g. `max_val=255` for 8-bit images.
    """
    return 10 * (max_val ** 2 / mean_squared_error(input, targs)).log10()
def exp_rmspe(pred:Tensor, targ:Tensor)->Rank0Tensor:
    """Root mean squared percentage error between `exp(pred)` and `exp(targ)`.

    Inputs are exponentiated first (presumably they are log-scale values); the relative error is
    taken with respect to the exponentiated target.
    """
    pred, targ = flatten_check(pred, targ)
    pred_exp, targ_exp = pred.exp(), targ.exp()
    rel_err = (targ_exp - pred_exp) / targ_exp
    return rel_err.pow(2).mean().sqrt()
def mean_absolute_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Average absolute difference between `pred` and `targ`, after passing both through `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    return (flat_targ - flat_pred).abs().mean()
def mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Average squared difference between `pred` and `targ` (via `F.mse_loss`), after `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    squared_err = F.mse_loss(flat_pred, flat_targ)
    return squared_err
def root_mean_squared_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Square root of the mean squared error between `pred` and `targ`, after `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    return F.mse_loss(flat_pred, flat_targ).sqrt()
def mean_squared_logarithmic_error(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Mean squared error between `log(1 + pred)` and `log(1 + targ)`, after `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    log_pred = torch.log(1 + flat_pred)
    log_targ = torch.log(1 + flat_targ)
    return F.mse_loss(log_pred, log_targ)
def explained_variance(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Explained variance: `1 - var(targ - pred) / var(targ)`, after `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    residual_var = (flat_targ - flat_pred).var()
    return 1 - residual_var / flat_targ.var()
def r2_score(pred:Tensor, targ:Tensor)->Rank0Tensor:
    "Coefficient of determination: `1 - SS_res / SS_tot`, after `flatten_check`."
    flat_pred, flat_targ = flatten_check(pred, targ)
    ss_res = ((flat_targ - flat_pred) ** 2).sum()
    ss_tot = ((flat_targ - flat_targ.mean()) ** 2).sum()
    return 1 - ss_res / ss_tot
class RegMetrics(Callback):
    "Base callback for regression metrics: collects every batch's outputs and targets on the CPU so subclasses can score the whole epoch."
    def on_epoch_begin(self, **kwargs):
        # Start each epoch with empty accumulators.
        self.preds = Tensor([])
        self.targs = Tensor([])
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        assert last_output.numel() == last_target.numel(), "Expected same numbers of elements in pred & targ"
        out_cpu, targ_cpu = last_output.cpu(), last_target.cpu()
        self.preds = torch.cat((self.preds, out_cpu))
        self.targs = torch.cat((self.targs, targ_cpu))
class R2Score(RegMetrics):
    "Epoch-level R2 score (coefficient of determination) over all collected predictions."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = r2_score(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class ExplainedVariance(RegMetrics):
    "Epoch-level explained variance over all collected predictions."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = explained_variance(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class RMSE(RegMetrics):
    "Epoch-level root mean squared error over all collected predictions."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = root_mean_squared_error(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class ExpRMSPE(RegMetrics):
    "Epoch-level root mean squared percentage error between the exponentiated predictions and targets (see `exp_rmspe`)."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = exp_rmspe(self.preds, self.targs)
        return add_metrics(last_metrics, score)
# Short aliases for the regression metric functions.
mse, mae, msle, rmse = (mean_squared_error, mean_absolute_error,
                        mean_squared_logarithmic_error, root_mean_squared_error)
class ConfusionMatrix(Callback):
    """Accumulates a confusion matrix over each epoch.

    `self.cm` is a float32 `(n_classes, n_classes)` tensor: rows index the target class, columns the
    predicted class (argmax over the last dim of the model output). At epoch end it is exposed as
    `self.metric`.
    """
    def on_train_begin(self, **kwargs):
        # 0 means "unknown yet": inferred from the width of the first batch's output.
        self.n_classes = 0
    def on_epoch_begin(self, **kwargs):
        self.cm = None
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        preds = last_output.argmax(-1).view(-1).cpu()
        # Flatten targets as well, so `(bs,)` and `(bs, 1)` targets are counted identically.
        targs = last_target.reshape(-1).cpu()
        if self.n_classes == 0:
            self.n_classes = last_output.shape[-1]
            self.x = torch.arange(0, self.n_classes)  # class indices, also read by subclasses
        n = self.n_classes
        # Count (target, prediction) pairs with a single bincount over `target * n + prediction`:
        # O(bs + n^2) memory instead of the O(n^2 * bs) boolean tensor a broadcast comparison needs.
        # Targets outside [0, n) match no class and are skipped, as before.
        valid = (targs >= 0) & (targs < n)
        pair_idx = targs[valid].long() * n + preds[valid]
        cm = torch.bincount(pair_idx, minlength=n * n).view(n, n).float()
        if self.cm is None: self.cm = cm
        else: self.cm += cm
    def on_epoch_end(self, **kwargs):
        self.metric = self.cm
@dataclass
class CMScores(ConfusionMatrix):
    "Base class for metrics which rely on the calculation of the precision and/or recall score."
    average:Optional[str]="binary" # `binary`, `micro`, `macro`, `weighted` or None
    pos_label:int=1 # 0 or 1
    eps:float=1e-9 # added to denominators so a 0/0 class score becomes 0 instead of nan
    def _recall(self):
        "Per-class recall (TP / actual count) when `average` is None, otherwise its weighted average."
        # Without `eps`, a class absent from the targets gives 0/0 = nan, and nan * 0-weight is
        # still nan, so one empty class made micro/binary averages nan. With eps it contributes 0;
        # non-empty classes are unchanged (1e-9 is below float32 resolution for counts >= 1).
        rec = torch.diag(self.cm) / (self.cm.sum(dim=1) + self.eps)
        if self.average is None: return rec
        # Micro recall equals recall averaged with support ("weighted") weights.
        weights = self._weights(avg="weighted" if self.average == "micro" else self.average)
        return (rec * weights).sum()
    def _precision(self):
        "Per-class precision (TP / predicted count) when `average` is None, otherwise its weighted average."
        # Same `eps` guard as in `_recall`: a class that is never predicted contributes 0, not nan.
        prec = torch.diag(self.cm) / (self.cm.sum(dim=0) + self.eps)
        if self.average is None: return prec
        weights = self._weights(avg=self.average)
        return (prec * weights).sum()
    def _weights(self, avg:str):
        "Per-class weights used to average a per-class score according to `avg`."
        if self.n_classes != 2 and avg == "binary":
            # Also overwrites `self.average`, so later calls go straight to the macro branch.
            avg = self.average = "macro"
            warn("average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.")
        if avg == "binary":
            if self.pos_label not in (0, 1):
                self.pos_label = 1
                warn("Invalid value for pos_label. It has now been set to 1.")
            if self.pos_label == 1: return Tensor([0,1])
            else: return Tensor([1,0])
        elif avg == "micro": return self.cm.sum(dim=0) / self.cm.sum()
        elif avg == "macro": return torch.ones((self.n_classes,)) / self.n_classes
        elif avg == "weighted": return self.cm.sum(dim=1) / self.cm.sum()
        raise ValueError(f"Unknown average {avg!r}: expected `binary`, `micro`, `macro`, `weighted` or None.")
class Recall(CMScores):
    "Recall, TP / (TP + FN) per class, averaged according to `average`."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = self._recall()
        return add_metrics(last_metrics, score)
class Precision(CMScores):
    "Precision, TP / (TP + FP) per class, averaged according to `average`."
    def on_epoch_end(self, last_metrics, **kwargs):
        score = self._precision()
        return add_metrics(last_metrics, score)
@dataclass
class FBeta(CMScores):
    "Computes the F`beta` score."
    beta:float=2
    def on_train_begin(self, **kwargs):
        self.n_classes = 0
        self.beta2 = self.beta ** 2
        # Keep the requested averaging in `self.avg`. Except for micro, per-class precision and
        # recall are needed (average=None) so F is computed per class before averaging.
        self.avg = self.average
        if self.average != "micro": self.average = None
    def on_epoch_end(self, last_metrics, **kwargs):
        if self.avg == "binary" and self.n_classes != 2:
            # Resolve binary -> macro here rather than in `_weights`. `_weights` would overwrite
            # `self.average` with "macro", and every later epoch would then average precision and
            # recall *before* combining them (F of macro P/R instead of macro F).
            warn("average=`binary` was selected for a non binary case. Value for average has now been set to `macro` instead.")
            self.avg = "macro"
        prec = self._precision()
        rec = self._recall()
        metric = (1 + self.beta2) * prec * rec / (prec * self.beta2 + rec + self.eps)
        metric[metric != metric] = 0 # removing potential "nan"s
        if self.avg: metric = (self._weights(avg=self.avg) * metric).sum()
        return add_metrics(last_metrics, metric)
    def on_train_end(self, **kwargs): self.average = self.avg
@dataclass
class KappaScore(ConfusionMatrix):
    "Cohen's kappa (unweighted, or with `linear`/`quadratic` disagreement weights) from the epoch's confusion matrix."
    weights:Optional[str]=None # None, `linear`, or `quadratic`
    def on_epoch_end(self, last_metrics, **kwargs):
        n = self.n_classes
        pred_counts = self.cm.sum(dim=0)
        true_counts = self.cm.sum(dim=1)
        # Chance agreement: outer product of the marginals, scaled back to the sample count.
        expected = pred_counts[:, None] * true_counts[None, :] / pred_counts.sum()
        if self.weights is None:
            # Every disagreement costs 1, agreement costs 0.
            penalty = torch.ones((n, n)) - torch.eye(n)
        elif self.weights in ("linear", "quadratic"):
            idx = torch.arange(n, dtype=torch.float)
            gap = idx[None, :] - idx[:, None]
            penalty = gap.abs() if self.weights == "linear" else gap ** 2
        else: raise ValueError('Unknown weights. Expected None, "linear", or "quadratic".')
        k = (penalty * self.cm).sum() / (penalty * expected).sum()
        return add_metrics(last_metrics, 1 - k)
@dataclass
class MatthewsCorreff(ConfusionMatrix):
    "Computes the Matthews correlation coefficient (multiclass form) from the epoch's confusion matrix."
    def on_epoch_end(self, last_metrics, **kwargs):
        # Work in float64: the covariance terms subtract products of order n_samples**2, which stop
        # being exact in float32 once n_samples exceeds ~4096.
        cm = self.cm.double()
        t_sum = cm.sum(dim=1)  # samples per target class (rows)
        p_sum = cm.sum(dim=0)  # samples per predicted class (columns)
        n_correct = torch.trace(cm)
        n_samples = p_sum.sum()
        cov_ytyp = n_correct * n_samples - torch.dot(t_sum, p_sum)
        cov_ypyp = n_samples ** 2 - torch.dot(p_sum, p_sum)
        cov_ytyt = n_samples ** 2 - torch.dot(t_sum, t_sum)
        denom = torch.sqrt(cov_ytyt * cov_ypyp)
        # A constant prediction or target makes the coefficient 0/0; report 0 as scikit-learn does
        # instead of propagating nan.
        mcc = cov_ytyp / denom if denom > 0 else torch.zeros_like(cov_ytyp)
        return add_metrics(last_metrics, mcc.to(self.cm.dtype))
class Perplexity(Callback):
    "Perplexity metric for language models: exp of the epoch's cross-entropy, each batch weighted by `last_target.size(1)`."
    def on_epoch_begin(self, **kwargs):
        self.loss = 0.
        self.len = 0
    def on_batch_end(self, last_output, last_target, **kwargs):
        # Weight factor is the target's size along dim 1 (presumably the sequence length — confirm).
        seq_len = last_target.size(1)
        self.loss += seq_len * CrossEntropyFlat()(last_output, last_target)
        self.len += seq_len
    def on_epoch_end(self, last_metrics, **kwargs):
        mean_loss = self.loss / self.len
        return add_metrics(last_metrics, torch.exp(mean_loss))
def auc_roc_score(input:Tensor, targ:Tensor):
    "Area under the ROC curve (from `roc_curve`), integrated with the trapezoid rule. Restricted to binary classification tasks."
    fpr, tpr = roc_curve(input, targ)
    widths = fpr[1:] - fpr[:-1]
    # Dividing by 2 is exact, so halving the heights first gives the same result as halving the products.
    mean_heights = (tpr[1:] + tpr[:-1]) / 2.
    return (widths * mean_heights).sum(-1)
def roc_curve(input:Tensor, targ:Tensor):
    """Receiver operating characteristic (ROC) curve: false and true positive rates at every
    distinct score threshold. Restricted to binary classification tasks.

    `input` holds one score per sample and `targ` the labels (`1` marks the positive class); both
    are treated as 1-D here. Returns `(fpr, tpr)`, starting at 0 and ending at 1 (nan when one of
    the classes is absent).
    """
    targ = (targ == 1)
    # Sort scores, and the labels alongside them, in descending order.
    desc_score_indices = torch.flip(input.argsort(-1), [-1])
    input = input[desc_score_indices]
    targ = targ[desc_score_indices]
    # Last index of each run of tied scores, plus the final index: one point per distinct threshold.
    d = input[1:] - input[:-1]
    distinct_value_indices = torch.nonzero(d).transpose(0,1)[0]
    last_index = distinct_value_indices.new_tensor([len(targ) - 1])
    threshold_idxs = torch.cat((distinct_value_indices, last_index))
    # Cumulative true / false positives at each threshold.
    tps = torch.cumsum(targ * 1, dim=-1)[threshold_idxs]
    fps = (1 + threshold_idxs - tps)
    if tps[0] != 0 or fps[0] != 0:
        # Prepend the (0, 0) point. Build the zero on the counts' own device: a CPU `LongTensor`
        # here cannot be concatenated with CUDA inputs.
        zero = tps.new_zeros(1)
        fps = torch.cat((zero, fps))
        tps = torch.cat((zero, tps))
    fpr, tpr = fps.float() / fps[-1], tps.float() / tps[-1]
    return fpr, tpr
@dataclass
class AUROC(Callback):
    "Area under the ROC curve over all of an epoch's predictions. Restricted to binary classification tasks."
    def on_epoch_begin(self, **kwargs):
        self.preds = Tensor([])
        self.targs = LongTensor([])
    def on_batch_end(self, last_output:Tensor, last_target:Tensor, **kwargs):
        # Score = softmax probability (over dim 1) of the last class.
        pos_prob = F.softmax(last_output, dim=1)[:, -1].cpu()
        self.preds = torch.cat((self.preds, pos_prob))
        self.targs = torch.cat((self.targs, last_target.cpu().long()))
    def on_epoch_end(self, last_metrics, **kwargs):
        score = auc_roc_score(self.preds, self.targs)
        return add_metrics(last_metrics, score)
class MultiLabelFbeta(LearnerCallback):
    """F-beta score for multi-label classification, accumulated over each epoch.

    An output counts as a positive prediction when it exceeds `thresh` (after a sigmoid when
    `sigmoid=True`). True positives, predicted counts and target counts are summed over dim 0 of
    each batch. `average` is one of "micro", "macro", "weighted" or "none" (one score per class,
    logged as `fbeta_<class>`).
    """
    # https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html
    _order = -20
    def __init__(self, learn, beta=2, eps=1e-15, thresh=0.3, sigmoid=True, average="micro"):
        # Validate up front instead of failing only after the first full training epoch.
        if average not in ("micro", "macro", "weighted", "none"):
            raise ValueError("Choose one of the average types: [micro, macro, weighted, none]")
        super().__init__(learn)
        self.eps, self.thresh, self.sigmoid, self.average, self.beta2 = \
            eps, thresh, sigmoid, average, beta**2
    def on_train_begin(self, **kwargs):
        self.c = self.learn.data.c
        if self.average != "none": self.learn.recorder.add_metric_names([f'{self.average}_fbeta'])
        else: self.learn.recorder.add_metric_names([f"fbeta_{c}" for c in self.learn.data.classes])
    def on_epoch_begin(self, **kwargs):
        # Fresh per-class accumulators on the data's device.
        dvc = self.learn.data.device
        self.tp = torch.zeros(self.c).to(dvc)
        self.total_pred = torch.zeros(self.c).to(dvc)
        self.total_targ = torch.zeros(self.c).to(dvc)
    def on_batch_end(self, last_output, last_target, **kwargs):
        pred, targ = (last_output.sigmoid() if self.sigmoid else last_output) > self.thresh, last_target.byte()
        m = pred*targ
        self.tp += m.sum(0).float()
        self.total_pred += pred.sum(0).float()
        self.total_targ += targ.sum(0).float()
    def fbeta_score(self, precision, recall):
        "F-beta from precision and recall; `eps` keeps the denominator non-zero."
        return (1 + self.beta2)*(precision*recall)/((self.beta2*precision + recall) + self.eps)
    def on_epoch_end(self, last_metrics, **kwargs):
        # Add eps to local copies so reading the metric does not modify the accumulators.
        total_pred = self.total_pred + self.eps
        total_targ = self.total_targ + self.eps
        if self.average == "micro":
            precision, recall = self.tp.sum() / total_pred.sum(), self.tp.sum() / total_targ.sum()
            res = self.fbeta_score(precision, recall)
        elif self.average == "macro":
            res = self.fbeta_score((self.tp / total_pred), (self.tp / total_targ)).mean()
        elif self.average == "weighted":
            scores = self.fbeta_score((self.tp / total_pred), (self.tp / total_targ))
            res = (scores*total_targ).sum() / total_targ.sum()
        elif self.average == "none":
            res = listify(self.fbeta_score((self.tp / total_pred), (self.tp / total_targ)))
        else:
            # Only reachable if `average` was changed after construction.
            raise ValueError("Choose one of the average types: [micro, macro, weighted, none]")
        return add_metrics(last_metrics, res)