from ..core import * from ..callback import * from ..basic_train import Learner, LearnerCallback __all__ = ['GeneralScheduler', 'TrainingPhase'] @dataclass class TrainingPhase(): "Schedule hyper-parameters for a phase of `length` iterations." length:int def __post_init__(self): self.scheds = dict() def schedule_hp(self, name, vals, anneal=None): "Adds a schedule for `name` between `vals` using `anneal`." self.scheds[name] = Scheduler(vals, self.length, anneal) return self class GeneralScheduler(LearnerCallback): "Schedule multiple `TrainingPhase` for a `Learner`." def __init__(self, learn:Learner, phases:Collection[TrainingPhase], start_epoch:int=None): super().__init__(learn) self.phases,self.start_epoch = phases,start_epoch def on_train_begin(self, epoch:int, **kwargs:Any)->None: "Initialize the schedulers for training." res = {'epoch':self.start_epoch} if self.start_epoch is not None else None self.start_epoch = ifnone(self.start_epoch, epoch) self.scheds = [p.scheds for p in self.phases] self.opt = self.learn.opt for k,v in self.scheds[0].items(): v.restart() self.opt.set_stat(k, v.start) self.idx_s = 0 return res def jump_to_epoch(self, epoch:int)->None: for _ in range(len(self.learn.data.train_dl) * epoch): self.on_batch_end(True) def on_batch_end(self, train, **kwargs:Any)->None: "Take a step in lr,mom sched, start next stepper when the current one is complete." if train: if self.idx_s >= len(self.scheds): return {'stop_training': True, 'stop_epoch': True} sched = self.scheds[self.idx_s] for k,v in sched.items(): self.opt.set_stat(k, v.step()) if list(sched.values())[0].is_done: self.idx_s += 1