# Image-Colorizer / DeOldify / fastai / general_optimizer.py
# sayed99 · project upload · cc9dfd7
from .torch_core import *
from torch.optim import Optimizer
import types
# Names exported by a star-import of this module.
__all__ = [
    'StatScope',
    'Statistic',
    'ConstStatistic',
    'AvgStatistic',
    'AvgSquare',
    'GeneralOptimizer',
]

# Granularity at which a statistic is kept: one value for the whole optimizer
# (`Global`), per parameter group, per parameter (`Layer`), per first-dimension
# slice of a parameter (`Channel`) or per individual element (`Weight`).
StatScope = Enum('StatScope', ['Global', 'Group', 'Layer', 'Channel', 'Weight'])
# Base description of a statistic: its name, hyper-parameter, scope and starting
# value. Subclasses supply `new_step`, `accumulate` and `update`.
@dataclass
class Statistic():
    name:str
    param:float=0.9  # hyper-parameter of the statistic, e.g. for an exponential moving average
    scope:StatScope=StatScope.Weight  # granularity at which the statistic is kept
    init:float=0.  # starting value of the buffer

    @property
    def buf(self):
        "Key of this statistic's buffer: the name followed by `_buffer`."
        return '{}_buffer'.format(self.name)

    def new_step(self):
        "Prepare the state used while computing statistics for `Global` or `Group` scope."
        raise NotImplementedError

    def accumulate(self, val):
        "Fold `val` into the statistic."
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Return the new state, from what was accumulated or from `val` (for `Weight` or `Layer` scope)."
        raise NotImplementedError
class ConstStatistic(Statistic):
    "Statistic without a buffer: `update` simply hands back the hyper-parameter `param`."

    @property
    def buf(self):
        "A constant statistic has no buffer key."
        return None

    def new_step(self):
        "Nothing to reset between steps."
        pass

    def accumulate(self, val=None):
        """Ignore `val`; a constant statistic accumulates nothing.

        `val` is accepted so the signature matches `Statistic.accumulate(val)`,
        which is how `GeneralOptimizer.update_stats` calls every Global/Group
        statistic. It stays optional so zero-argument calls keep working.
        """
        pass

    def update(self, state, param, val=None, step=None):
        "Return `param` unchanged; `state`, `val` and `step` are not used."
        return param
# Statistic whose `update` returns `state + 1`. Its buffer key is the name it was
# constructed with (not `<name>_buffer`), and its `name` field is cleared.
@dataclass
class CounterStat(Statistic):
    def __post_init__(self):
        # Keep the requested name as the buffer key, then blank the public name
        # and force the starting value to zero.
        self._buf = self.name
        self.name = None
        self.init = 0

    @property
    def buf(self):
        "Buffer key: the name passed to the constructor."
        return self._buf

    def new_step(self):
        "No per-step state to reset."

    def accumulate(self, val):
        "`val` is ignored; the counter does not depend on it."

    def update(self, state, param, val=None, step=None):
        "Return the counter advanced by one; `param`, `val` and `step` are unused."
        return state + 1
# Running average of the values it is fed, kept at the granularity given by `scope`.
@dataclass
class AvgStatistic(Statistic):
    decay:bool=False  # if True, new values are weighted by `1-param`; otherwise by 1
    debias:bool=False  # if True, the result is divided by `1 - param**step` when `step` is given

    def new_step(self):
        "Reset the running sum and count used for `Global`/`Group` scope."
        self.val,self.count = 0.,0

    def accumulate(self, val):
        "Add the scalar summary of `val` (see `_get_val1`) to the running sum."
        self.count += 1
        self.val += self._get_val1(val)

    def _get_val1(self, val):
        "Scalar summary of `val`: its mean."
        return val.mean()

    def _get_val2(self, state, val, param):
        "Add `val` to `state` in place (scaled by `1-param` when `decay` is set) and return `state`."
        # `alpha=` is the supported spelling; the positional `add_(scalar, tensor)`
        # overload is deprecated in PyTorch.
        return state.add_(val, alpha=1-param) if self.decay else state.add_(val)

    def _get_val3(self, state, val, param):
        "Same as `_get_val2`, but on the mean of `val` over everything except its first dimension."
        v = val.view(val.size(0), -1).mean(1)
        return state.add_(v, alpha=1-param) if self.decay else state.add_(v)

    def update(self, state, param, val=None, step=None):
        """Return the new value of the statistic.

        `state` is the previous value, `param` the averaging coefficient, `val`
        the fresh tensor (used for `Weight`, `Channel` and `Layer` scope) and
        `step` the step count used for debiasing, if any.
        """
        if self.scope == StatScope.Weight:
            # `state` is a tensor; it is scaled and updated in place
            res = self._get_val2(state.mul_(param), val, param)
        elif self.scope == StatScope.Channel:
            # `state` is a tensor of size n_channels; also updated in place
            res = self._get_val3(state.mul_(param), val, param)
        # For everything else, `state` is a scalar
        elif self.scope == StatScope.Layer:
            res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)
        elif self.count != 0:
            # Global/Group scope: use the average of what `accumulate` collected this step
            res = state*param + self.val/self.count * (1-param if self.decay else 1.)
        else:
            # Nothing was accumulated this step: leave the state untouched
            return state
        # For tensor states this division is in place, so the returned (and
        # stored) value is the debiased one.
        if self.debias and step is not None: res /= (1 - param ** step)
        return res
class AvgSquare(AvgStatistic):
    "Running average of squared values; same interface as `AvgStatistic`, with `decay` on by default."

    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val):
        "Scalar summary of `val`: its squared norm divided by its number of elements."
        return torch.norm(val).pow(2)/val.numel()

    def _get_val2(self, state, val, param):
        "Add `val*val` to `state` in place (scaled by `1-param` when `decay` is set) and return `state`."
        # `value=` is the supported spelling; the positional
        # `addcmul_(scalar, t1, t2)` overload is deprecated in PyTorch.
        return state.addcmul_(val, val, value=1-param) if self.decay else state.addcmul_(val, val)

    def _get_val3(self, state, val, param):
        "Same as `_get_val2`, but on the mean of `val` over everything except its first dimension."
        # NOTE(review): this squares the per-channel mean, whereas `_get_val1`
        # averages the squared entries - confirm which one is intended.
        v = val.view(val.size(0), -1).mean(1)
        return state.addcmul_(v, v, value=1-param) if self.decay else state.addcmul_(v, v)
class GeneralOptimizer(Optimizer):
    """Optimizer that maintains a set of statistics and applies `on_step` to each parameter.

    `stats` is a `Statistic` or list of them; each named one contributes its
    `param` as a param-group default under its `name`. `on_step(self, p, group,
    group_idx)`, if given, replaces the default update and is bound as a method.
    """

    def __init__(self, params, stats=None, on_step:Callable=None):
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)
        self.init_stats()
        if on_step is not None: self.on_step = types.MethodType(on_step, self)

    def step(self, closure=None):
        "Refresh all statistics, then call `on_step` on every parameter that has a gradient."
        # Follow the `Optimizer.step` convention: evaluate the optional closure
        # (it recomputes the loss and gradients) and hand its result back.
        loss = None
        if closure is not None:
            with torch.enable_grad(): loss = closure()
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)
        return loss

    def on_step(self, p, group, group_idx):
        "Default update: subtract `group['lr']` times the gradient from `p`, in place."
        # Keyword `alpha` instead of the deprecated `add_(scalar, tensor)` overload.
        p.data.add_(p.grad.data, alpha=-group['lr'])

    def _split_stats(self, stats):
        "Return five lists of stats, one per `StatScope`, adding a `step` counter where debiasing needs it."
        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]
        global_stats,group_stats,layer_stats,channel_stats,weight_stats = splits
        # Layer, channel and weight stats share one per-parameter state dict and
        # are always updated together as `layer_stats+channel_stats+weight_stats`.
        # Their counter therefore goes at the front of `layer_stats`, so it is
        # kept (inserting into a temporary concatenation would lose it) and is
        # updated before any stat that reads `step`.
        per_state = [(global_stats, global_stats),
                     (group_stats, group_stats),
                     (layer_stats+channel_stats+weight_stats, layer_stats)]
        for (members,target),scope in zip(per_state, StatScope):
            if any(getattr(s, 'debias', False) for s in members):
                target.insert(0, CounterStat('step', scope=scope))
        return splits

    def _init_stats(self, stats, data=None):
        "Build `{buf: initial value}` for `stats`; values are tensors shaped like `data` when it is given."
        return {stat.buf: stat.init if data is None
                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}

    def init_stats(self):
        "Create the state entries: `'global'`, one `'group{i}'` per param group, and one per parameter."
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                # Channel buffers take the shape of the per-first-dimension mean of the parameter.
                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))

    def _set_bufs(self, p, stats, pg, val=None):
        "Update, in `self.state[p]`, the buffer of every stat in `stats` that has one."
        d = self.state[p]
        for stat in stats:
            if stat.buf is None: continue
            # A stat without a name (the step counter) has no hyper-parameter in
            # the param group; looking up `pg[None]` would raise `KeyError`.
            param = pg[stat.name] if stat.name is not None else None
            d[stat.buf] = stat.update(d[stat.buf], param, val=val, step=d.get('step', None))

    def update_stats(self):
        "Recompute every statistic from the current gradients."
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        # Global stats read their hyper-parameters from the first param group.
        self._set_bufs('global', self.global_stats, self.param_groups[0])