# sayed99's picture
# project upload
# cc9dfd7
from ..torch_core import *
from ..callback import *
from ..basic_train import Learner, LearnerCallback
# Names exported when this module is star-imported (`from <module> import *`).
__all__ = ['LossMetrics']
class LossMetrics(LearnerCallback):
    """Add `loss_func.metrics` to metrics named by `loss_func.metric_names`.

    On every non-training batch the values in `learn.loss_func.metrics` are
    accumulated, weighted by the size of `last_target` along dimension 0.
    At the end of the epoch the weighted averages are appended to
    `last_metrics`.
    """
    _order = -20  # Needs to run before the recorder

    def on_train_begin(self, **kwargs):
        "Add the metrics names to the `Recorder`."
        # Use `getattr` instead of plain attribute access: a loss function that
        # has no `metric_names` attribute at all is handled like one whose
        # attribute is `None`, so the warning below is emitted rather than an
        # `AttributeError` being raised before it can be reached.
        self.names = ifnone(getattr(self.learn.loss_func, 'metric_names', None), [])
        if not self.names: warn('LossMetrics requested but no loss_func.metric_names provided')
        self.learn.recorder.add_metric_names(self.names)

    def on_epoch_begin(self, **kwargs):
        "Initialize the metrics for this epoch."
        # One running weighted sum per metric name, plus the total weight.
        self.metrics = {name: 0. for name in self.names}
        self.nums = 0

    def on_batch_end(self, last_target, train, **kwargs):
        "Update the metrics if not `train`"
        if train: return
        bs = last_target.size(0)
        for name in self.names:
            # `loss_func.metrics` is looked up inside the loop on purpose: when
            # `self.names` is empty the attribute is never touched.
            # Detach and move to CPU so the running sum holds no autograd graph.
            self.metrics[name] += bs * self.learn.loss_func.metrics[name].detach().cpu()
        self.nums += bs

    def on_epoch_end(self, last_metrics, **kwargs):
        """Finish the computation and send the result to the Recorder.

        Returns `None` when nothing was accumulated during the epoch
        (`self.nums` is 0); otherwise a dict whose `'last_metrics'` entry is
        `last_metrics` extended with one averaged value per name in
        `self.names`, in that order.
        """
        if not self.nums: return
        metrics = [self.metrics[name] / self.nums for name in self.names]
        return {'last_metrics': last_metrics + metrics}