sayed99's picture
project upload
cc9dfd7
"Callbacks provides extensibility to the `basic_train` loop. See `train` for examples of custom callbacks."
from .basic_data import *
from .torch_core import *
import torch.distributed as dist
__all__ = ['AverageMetric', 'Callback', 'CallbackHandler', 'OptimWrapper', 'SmoothenValue', 'Scheduler', 'annealing_cos', 'CallbackList',
'annealing_exp', 'annealing_linear', 'annealing_no', 'annealing_poly']
class OptimWrapper():
"Basic wrapper around `opt` to simplify hyper-parameters changes."
def __init__(self, opt:optim.Optimizer, wd:Floats=0., true_wd:bool=False, bn_wd:bool=True):
assert not isinstance(opt, OptimWrapper)
self.opt,self.true_wd,self.bn_wd = opt,true_wd,bn_wd
self.opt_keys = list(self.opt.param_groups[0].keys())
self.opt_keys.remove('params')
self.read_defaults()
self.wd = wd
@classmethod
def create(cls, opt_func:Union[type,Callable], lr:Union[float,Tuple,List], layer_groups:ModuleList, wd:Floats=0.,
true_wd:bool=False, bn_wd:bool=True)->optim.Optimizer:
"Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
split_params = split_no_wd_params(layer_groups)
opt = opt_func([{'params': p, 'lr':0} for p in split_params])
opt = cls(opt, wd=wd, true_wd=true_wd, bn_wd=bn_wd)
opt.lr,opt.opt_func = listify(lr, layer_groups),opt_func
return opt
def new(self, layer_groups:Collection[nn.Module], split_no_wd:bool=True):
"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
opt_func = getattr(self, 'opt_func', self.opt.__class__)
res = self.create(opt_func, self.lr, layer_groups, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
res.mom,res.beta = self.mom,self.beta
return res
def new_with_params(self, param_groups:Collection[Collection[nn.Parameter]]):
"Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
opt_func = getattr(self, 'opt_func', self.opt.__class__)
opt = opt_func([{'params': p, 'lr':0} for p in param_groups])
opt = self.__class__(opt, wd=self.wd, true_wd=self.true_wd, bn_wd=self.bn_wd)
opt.lr,opt.opt_func,opt.mom,opt.beta = self.lr,opt_func,self.mom,self.beta
return opt
def __repr__(self)->str:
return f'OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}'
#Pytorch optimizer methods
def step(self)->None:
"Set weight decay and step optimizer."
# weight decay outside of optimizer step (AdamW)
if self.true_wd:
for lr,wd,pg1,pg2 in zip(self._lr,self._wd,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
for p in pg1['params']: p.data.mul_(1 - wd*lr)
if self.bn_wd:
for p in pg2['params']: p.data.mul_(1 - wd*lr)
self.set_val('weight_decay', listify(0, self._wd))
self.opt.step()
def zero_grad(self)->None:
"Clear optimizer gradients."
self.opt.zero_grad()
#Passthrough to the inner opt.
def __getattr__(self, k:str)->Any: return getattr(self.opt, k, None)
def __setstate__(self,data:Any): self.__dict__.update(data)
def clear(self):
"Reset the state of the inner optimizer."
sd = self.state_dict()
sd['state'] = {}
self.load_state_dict(sd)
@property
def n_params(self): return sum([len(pg['params']) for pg in self.opt.param_groups])
#Hyperparameters as properties
@property
def lr(self)->float: return self._lr[-1]
@lr.setter
def lr(self, val:float)->None:
self._lr = self.set_val('lr', listify(val, self._lr))
@property
def mom(self)->float:return self._mom[-1]
@mom.setter
def mom(self, val:float)->None:
if 'momentum' in self.opt_keys: self.set_val('momentum', listify(val, self._mom))
elif 'betas' in self.opt_keys: self.set_val('betas', (listify(val, self._mom), self._beta))
self._mom = listify(val, self._mom)
@property
def beta(self)->float: return None if self._beta is None else self._beta[-1]
@beta.setter
def beta(self, val:float)->None:
"Set beta (or alpha as makes sense for given optimizer)."
if val is None: return
if 'betas' in self.opt_keys: self.set_val('betas', (self._mom, listify(val, self._beta)))
elif 'alpha' in self.opt_keys: self.set_val('alpha', listify(val, self._beta))
self._beta = listify(val, self._beta)
@property
def wd(self)->float: return self._wd[-1]
@wd.setter
def wd(self, val:float)->None:
"Set weight decay."
if not self.true_wd: self.set_val('weight_decay', listify(val, self._wd), bn_groups=self.bn_wd)
self._wd = listify(val, self._wd)
#Helper functions
def read_defaults(self)->None:
"Read the values inside the optimizer for the hyper-parameters."
self._beta = None
if 'lr' in self.opt_keys: self._lr = self.read_val('lr')
if 'momentum' in self.opt_keys: self._mom = self.read_val('momentum')
if 'alpha' in self.opt_keys: self._beta = self.read_val('alpha')
if 'betas' in self.opt_keys: self._mom,self._beta = self.read_val('betas')
if 'weight_decay' in self.opt_keys: self._wd = self.read_val('weight_decay')
reserved_names = ['params', 'lr', 'momentum', 'alpha', 'betas', 'weight_decay']
stat_names = [n for n in self.opt_keys if n not in reserved_names]
self._stats = {n:self.read_val(n) for n in stat_names}
def get_stat(self, name:str)->float:
if name in ['lr', 'mom', 'beta', 'wd']: return getattr(self, name)
else: return self._stats[name][-1]
def set_stat(self, name:str, value:Union[float, Collection[float]])->None:
if name in ['lr', 'mom', 'beta', 'wd']: setattr(self, name, value)
else:
val = listify(value, self._stats[name])
self.set_val(name, val)
self._stats[name] = val
def set_val(self, key:str, val:Any, bn_groups:bool=True)->Any:
"Set `val` inside the optimizer dictionary at `key`."
if is_tuple(val): val = [(v1,v2) for v1,v2 in zip(*val)]
for v,pg1,pg2 in zip(val,self.opt.param_groups[::2],self.opt.param_groups[1::2]):
pg1[key] = v
if bn_groups: pg2[key] = v
return val
def read_val(self, key:str) -> Union[List[float],Tuple[List[float],List[float]]]:
"Read a hyperparameter `key` in the optimizer dictionary."
val = [pg[key] for pg in self.opt.param_groups[::2]]
if is_tuple(val[0]): val = [o[0] for o in val], [o[1] for o in val]
return val
def get_state(self):
"Return the inner state minus the layer groups."
return {'opt_state':self.opt.state_dict(), 'lr':self._lr, 'wd':self._wd, 'beta':self._beta, 'mom':self._mom,
'opt_func':self.opt_func, 'true_wd':self.true_wd, 'bn_wd':self.bn_wd}
@classmethod
def load_with_state_and_layer_group(cls, state:dict, layer_groups:Collection[nn.Module]):
res = cls.create(state['opt_func'], state['lr'], layer_groups, wd=state['wd'], true_wd=state['true_wd'],
bn_wd=state['bn_wd'])
res._mom,res._beta = state['mom'],state['beta']
res.load_state_dict(state['opt_state'])
return res
class Callback():
"Base class for callbacks that want to record values, dynamically change learner params, etc."
_order=0
def on_train_begin(self, **kwargs:Any)->None:
"To initialize constants in the callback."
pass
def on_epoch_begin(self, **kwargs:Any)->None:
"At the beginning of each epoch."
pass
def on_batch_begin(self, **kwargs:Any)->None:
"Set HP before the output and loss are computed."
pass
def on_loss_begin(self, **kwargs:Any)->None:
"Called after forward pass but before loss has been computed."
pass
def on_backward_begin(self, **kwargs:Any)->None:
"Called after the forward pass and the loss has been computed, but before backprop."
pass
def on_backward_end(self, **kwargs:Any)->None:
"Called after backprop but before optimizer step. Useful for true weight decay in AdamW."
pass
def on_step_end(self, **kwargs:Any)->None:
"Called after the step of the optimizer but before the gradients are zeroed."
pass
def on_batch_end(self, **kwargs:Any)->None:
"Called at the end of the batch."
pass
def on_epoch_end(self, **kwargs:Any)->None:
"Called at the end of an epoch."
pass
def on_train_end(self, **kwargs:Any)->None:
"Useful for cleaning up things and saving files/models."
pass
def jump_to_epoch(self, epoch)->None:
"To resume training at `epoch` directly."
pass
def get_state(self, minimal:bool=True):
"Return the inner state of the `Callback`, `minimal` or not."
to_remove = ['exclude', 'not_min'] + getattr(self, 'exclude', []).copy()
if minimal: to_remove += getattr(self, 'not_min', []).copy()
return {k:v for k,v in self.__dict__.items() if k not in to_remove}
def __repr__(self):
attrs = func_args(self.__init__)
to_remove = getattr(self, 'exclude', [])
list_repr = [self.__class__.__name__] + [f'{k}: {getattr(self, k)}' for k in attrs if k != 'self' and k not in to_remove]
return '\n'.join(list_repr)
class SmoothenValue():
"Create a smooth moving average for a value (loss, etc) using `beta`."
def __init__(self, beta:float):
self.beta,self.n,self.mov_avg = beta,0,0
def add_value(self, val:float)->None:
"Add `val` to calculate updated smoothed value."
self.n += 1
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
self.smooth = self.mov_avg / (1 - self.beta ** self.n)
CallbackList = Collection[Callback]
def _get_init_state(): return {'epoch':0, 'iteration':0, 'num_batch':0, 'skip_validate': False}
@dataclass
class CallbackHandler():
"Manage all of the registered `callbacks` and `metrics`, smoothing loss by momentum `beta`."
callbacks:CallbackList=None
metrics:CallbackList=None
beta:float=0.98
def __post_init__(self)->None:
"Initialize smoother and learning stats."
self.callbacks = ifnone(self.callbacks, [])
self.metrics = ifnone(self.metrics, [])
self.metrics = [(met if isinstance(met, Callback) else AverageMetric(met)) for met in self.metrics]
self.callbacks = sorted(self.callbacks, key=lambda o: getattr(o, '_order', 0))
self.smoothener = SmoothenValue(self.beta)
self.state_dict:Dict[str,Union[int,float,Tensor]]=_get_init_state()
def _call_and_update(self, cb, cb_name, **kwargs)->None:
"Call `cb_name` on `cb` and update the inner state."
new = ifnone(getattr(cb, f'on_{cb_name}')(**self.state_dict, **kwargs), dict())
for k,v in new.items():
if k not in self.state_dict:
raise Exception(f"{k} isn't a valid key in the state of the callbacks.")
else: self.state_dict[k] = v
def __call__(self, cb_name, call_mets=True, **kwargs)->None:
"Call through to all of the `CallbakHandler` functions."
if call_mets:
for met in self.metrics: self._call_and_update(met, cb_name, **kwargs)
for cb in self.callbacks: self._call_and_update(cb, cb_name, **kwargs)
def set_dl(self, dl:DataLoader):
"Set the current `dl` used."
if hasattr(self, 'cb_dl'): self.callbacks.remove(self.cb_dl)
if isinstance(dl.dataset, Callback):
self.callbacks.append(dl.dataset)
self.cb_dl = dl.dataset
def on_train_begin(self, epochs:int, pbar:PBar, metrics:MetricFuncList)->None:
"About to start learning."
self.state_dict = _get_init_state()
self.state_dict.update(dict(n_epochs=epochs, pbar=pbar, metrics=metrics))
names = [(met.name if hasattr(met, 'name') else camel2snake(met.__class__.__name__)) for met in self.metrics]
self('train_begin', metrics_names=names)
if self.state_dict['epoch'] != 0:
self.state_dict['pbar'].first_bar.total -= self.state_dict['epoch']
for cb in self.callbacks: cb.jump_to_epoch(self.state_dict['epoch'])
def on_epoch_begin(self)->None:
"Handle new epoch."
self.state_dict['num_batch'],self.state_dict['stop_training'] = 0,False
self('epoch_begin')
def on_batch_begin(self, xb:Tensor, yb:Tensor, train:bool=True)->Tuple[Any,Any]:
"Handle new batch `xb`,`yb` in `train` or validation."
self.state_dict.update(dict(last_input=xb, last_target=yb, train=train,
stop_epoch=False, skip_step=False, skip_zero=False, skip_bwd=False))
self('batch_begin', mets = not self.state_dict['train'])
return self.state_dict['last_input'], self.state_dict['last_target']
def on_loss_begin(self, out:Tensor)->Any:
"Handle start of loss calculation with model output `out`."
self.state_dict['last_output'] = out
self('loss_begin', call_mets=False)
return self.state_dict['last_output']
def on_backward_begin(self, loss:Tensor)->Tuple[Any,Any]:
"Handle gradient calculation on `loss`."
self.smoothener.add_value(loss.detach().cpu())
self.state_dict['last_loss'], self.state_dict['smooth_loss'] = loss, self.smoothener.smooth
self('backward_begin', call_mets=False)
return self.state_dict['last_loss'], self.state_dict['skip_bwd']
def on_backward_end(self)->Any:
"Handle end of gradient calculation."
self('backward_end', call_mets=False)
return self.state_dict['skip_step']
def on_step_end(self)->Any:
"Handle end of optimization step."
self('step_end', call_mets=False)
return self.state_dict['skip_zero']
def on_batch_end(self, loss:Tensor)->Any:
"Handle end of processing one batch with `loss`."
self.state_dict['last_loss'] = loss
self('batch_end', call_mets = not self.state_dict['train'])
if self.state_dict['train']:
self.state_dict['iteration'] += 1
self.state_dict['num_batch'] += 1
return self.state_dict['stop_epoch']
def on_epoch_end(self, val_loss:Tensor)->bool:
"Epoch is done, process `val_loss`."
self.state_dict['last_metrics'] = [val_loss] if val_loss is not None else [None]
self('epoch_end', call_mets = val_loss is not None)
self.state_dict['epoch'] += 1
return self.state_dict['stop_training']
def on_train_end(self, exception:Union[bool,Exception])->None:
"Handle end of training, `exception` is an `Exception` or False if no exceptions during training."
self('train_end', exception=exception)
@property
def skip_validate(self): return self.state_dict['skip_validate']
class AverageMetric(Callback):
"Wrap a `func` in a callback for metrics computation."
def __init__(self, func):
# If func has a __name__ use this one else it should be a partial
name = func.__name__ if hasattr(func, '__name__') else func.func.__name__
self.func, self.name = func, name
self.world = num_distrib()
def on_epoch_begin(self, **kwargs):
"Set the inner value to 0."
self.val, self.count = 0.,0
def on_batch_end(self, last_output, last_target, **kwargs):
"Update metric computation with `last_output` and `last_target`."
if not is_listy(last_target): last_target=[last_target]
self.count += first_el(last_target).size(0)
val = self.func(last_output, *last_target)
if self.world:
val = val.clone()
dist.all_reduce(val, op=dist.ReduceOp.SUM)
val /= self.world
self.val += first_el(last_target).size(0) * val.detach().cpu()
def on_epoch_end(self, last_metrics, **kwargs):
"Set the final result in `last_metrics`."
return add_metrics(last_metrics, self.val/self.count)
def annealing_no(start:Number, end:Number, pct:float)->Number:
"No annealing, always return `start`."
return start
def annealing_linear(start:Number, end:Number, pct:float)->Number:
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
return start + pct * (end-start)
def annealing_exp(start:Number, end:Number, pct:float)->Number:
"Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0."
return start * (end/start) ** pct
def annealing_cos(start:Number, end:Number, pct:float)->Number:
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
cos_out = np.cos(np.pi * pct) + 1
return end + (start-end)/2 * cos_out
def do_annealing_poly(start:Number, end:Number, pct:float, degree:Number)->Number:
"Helper function for `anneal_poly`."
return end + (start-end) * (1-pct)**degree
def annealing_poly(degree:Number)->Number:
"Anneal polynomically from `start` to `end` as pct goes from 0.0 to 1.0."
return functools.partial(do_annealing_poly, degree=degree)
class Scheduler():
"Used to \"step\" from start,end (`vals`) over `n_iter` iterations on a schedule defined by `func`"
def __init__(self, vals:StartOptEnd, n_iter:int, func:Optional[AnnealFunc]=None):
self.start,self.end = (vals[0],vals[1]) if is_tuple(vals) else (vals,0)
self.n_iter = max(1,n_iter)
if func is None: self.func = annealing_linear if is_tuple(vals) else annealing_no
else: self.func = func
self.n = 0
def restart(self): self.n = 0
def step(self)->Number:
"Return next value along annealed schedule."
self.n += 1
return self.func(self.start, self.end, self.n/self.n_iter)
@property
def is_done(self)->bool:
"Return `True` if schedule completed."
return self.n >= self.n_iter