|
from .torch_core import * |
|
from torch.optim import Optimizer |
|
import types |
|
|
|
# Public names exported by `from <this module> import *`.
__all__ = [
    'StatScope', 'Statistic', 'ConstStatistic',
    'AvgStatistic', 'AvgSquare', 'GeneralOptimizer',
]
|
|
|
# Granularity at which a `Statistic` is tracked: one value for the whole optimizer (Global),
# per param group (Group), per parameter (Layer), per first-dimension slice (Channel),
# or per parameter element (Weight).
StatScope = Enum('StatScope', ['Global', 'Group', 'Layer', 'Channel', 'Weight'])
|
|
|
@dataclass
class Statistic():
    "Base class for a quantity tracked by `GeneralOptimizer`; subclasses implement `new_step`, `accumulate` and `update`."
    name:str                          # name of the hyper-parameter in the param groups; also prefixes the buffer name
    param:float=0.9                   # default hyper-parameter value, handed to `update` as `param`
    scope:StatScope=StatScope.Weight  # granularity at which the statistic is tracked
    init:float=0.                     # starting value of the state buffer

    @property
    def buf(self):
        "Key of the optimizer-state entry that holds this statistic."
        return '{}_buffer'.format(self.name)

    def new_step(self):
        "Reset any accumulator before a new optimizer step (used by `Global` and `Group` scopes)."
        raise NotImplementedError

    def accumulate(self, val):
        "Fold `val` into the running accumulator."
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Return the new state from the accumulated values, or from `val` for `Weight`/`Layer` scope."
        raise NotImplementedError
|
|
|
class ConstStatistic(Statistic):
    "Statistic without a state buffer: `update` just hands back the hyper-parameter `param`."
    @property
    def buf(self): return None  # no buffer, so `GeneralOptimizer` never initialises or updates state for it

    def new_step(self): pass

    # `val` matches the base-class signature: `GeneralOptimizer.update_stats` calls
    # `stat.accumulate(grad)` on every Global/Group stat. The default keeps old zero-argument calls working.
    def accumulate(self, val=None): pass

    def update(self, state, param, val=None, step=None): return param
|
|
|
@dataclass
class CounterStat(Statistic):
    "Step counter: each `update` returns the previous count plus one, starting from 0."
    def __post_init__(self):
        # The constructor's `name` becomes the buffer key. `name` itself is cleared so the counter
        # registers no hyper-parameter (GeneralOptimizer skips stats whose name is None).
        self._buf = self.name
        self.name = None
        self.init = 0

    @property
    def buf(self):
        "Buffer key given at construction time."
        return self._buf

    def new_step(self):
        "Nothing to reset."

    def accumulate(self, val):
        "Nothing to accumulate."

    def update(self, state, param, val=None, step=None):
        "Increment the count."
        return state + 1
|
|
|
@dataclass
class AvgStatistic(Statistic):
    "Running average (or decayed sum) of gradient values, reduced according to `scope`."
    decay:bool=False   # if True, new values are weighted by `1-param`; otherwise they are added with weight 1
    debias:bool=False  # if True (and a step count is supplied), the returned value is divided by `1 - param**step`

    def new_step(self):
        "Reset the accumulator used by `Global` and `Group` scopes."
        self.val,self.count = 0.,0

    def accumulate(self, val):
        "Add the scalar reduction of `val` to the accumulator."
        self.count += 1
        self.val += self._get_val1(val)

    def _get_val1(self, val):
        "Reduce `val` to a scalar (its mean)."
        return val.mean()

    def _get_val2(self, state, val, param):
        "Add `val` element-wise into `state`, in place."
        return state.add_(val, alpha=1-param) if self.decay else state.add_(val)

    def _get_val3(self, state, val, param):
        "Add the mean of `val` over every dimension but the first into `state`, in place."
        # `reshape` instead of `view` so non-contiguous gradients are accepted
        v = val.reshape(val.size(0), -1).mean(1)
        return state.add_(v, alpha=1-param) if self.decay else state.add_(v)

    def update(self, state, param, val=None, step=None):
        "Return the new value of the statistic from its previous `state`, hyper-parameter `param` and new `val`."
        debias = self.debias and step is not None
        # The stored state is the value returned at step-1, which was already debiased by 1-param**(step-1).
        # Undo that factor to recover the raw average, otherwise the debiasing compounds every step.
        # NOTE(review): exact only if `param` did not change since the previous step.
        if debias and step > 1: state = state * (1 - param ** (step - 1))
        if self.scope == StatScope.Weight:
            # `state` is a tensor shaped like the parameter
            res = self._get_val2(state.mul_(param), val, param)
        elif self.scope == StatScope.Channel:
            # `state` holds one value per slice along the parameter's first dimension
            res = self._get_val3(state.mul_(param), val, param)
        # For every other scope, `state` is a scalar
        elif self.scope == StatScope.Layer: res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)
        elif self.count != 0: res = state*param + self.val/self.count * (1-param if self.decay else 1.)
        else: res = state  # nothing accumulated this step: raw average unchanged
        if debias: res = res / (1 - param ** step)
        return res
|
|
|
class AvgSquare(AvgStatistic):
    "Running average of squared gradient values; `decay` defaults to True."

    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val):
        "Reduce `val` to a scalar: the mean of its squared entries."
        return torch.norm(val).pow(2)/val.numel()

    def _get_val2(self, state, val, param):
        "Add `val**2` element-wise into `state`, in place."
        return state.addcmul_(val, val, value=1-param) if self.decay else state.addcmul_(val, val)

    def _get_val3(self, state, val, param):
        "Add the mean of `val**2` over every dimension but the first into `state`, in place."
        # Square before reducing, so Channel scope averages squares the same way Layer/Global scopes do
        v = val.reshape(val.size(0), -1).pow(2).mean(1)
        return state.add_(v, alpha=1-param) if self.decay else state.add_(v)
|
|
|
class GeneralOptimizer(Optimizer):
    "Optimizer whose state is a set of `Statistic`s updated every step, followed by a per-parameter `on_step`."
    def __init__(self, params, stats=None, on_step:Callable=None):
        # Every named statistic contributes its hyper-parameter to the param-group defaults
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        (self.global_stats,self.group_stats,self.layer_stats,
         self.channel_stats,self.weight_stats) = self._split_stats(stats)
        self.init_stats()
        if on_step is not None: self.on_step = types.MethodType(on_step, self)

    def step(self, closure=None):
        "Optionally re-evaluate `closure`, update all statistics, then call `on_step` for each parameter with a gradient."
        loss = None
        if closure is not None:
            with torch.enable_grad(): loss = closure()
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)
        return loss

    def on_step(self, p, group, group_idx):
        "Default update: plain SGD, `p -= lr * grad`."
        p.data.add_(p.grad.data, alpha=-group['lr'])

    def _split_stats(self, stats):
        "Split `stats` by scope and prepend a `step` counter to each level holding a debiased statistic."
        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]
        # Layer, Channel and Weight stats all live in `self.state[p]`, so they share one counter.
        # It goes into the real layer list (updated first in `update_stats`), not a temporary concatenation.
        levels = [(splits[0], splits[0], StatScope.Global),
                  (splits[1], splits[1], StatScope.Group),
                  (splits[2]+splits[3]+splits[4], splits[2], StatScope.Layer)]
        for candidates,target,scope in levels:
            if any(getattr(stat, 'debias', False) for stat in candidates):
                target.insert(0, CounterStat('step', scope=scope))
        return splits

    def _init_stats(self, stats, data=None):
        "Initial buffers for `stats`: the scalar `init` if `data` is None, else a tensor like `data` filled with `init`."
        return {stat.buf: stat.init if data is None else torch.zeros_like(data) + stat.init
                for stat in stats if stat.buf is not None}

    def init_stats(self):
        "Create the state dicts for 'global', for each param group ('group{i}') and for each parameter."
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                # The per-channel template needs at least one dimension, so only build it when needed
                if self.channel_stats:
                    self.state[p].update(self._init_stats(self.channel_stats, p.data.reshape(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))

    def _set_bufs(self, p, stats, pg, val=None):
        "Store the updated value of each buffered stat in `self.state[p]`."
        d = self.state[p]
        for stat in stats:
            if stat.buf is None: continue
            # Unnamed stats (the step counter) have no hyper-parameter in the param group
            hyper = None if stat.name is None else pg[stat.name]
            d[stat.buf] = stat.update(d[stat.buf], hyper, val=val, step=d.get('step', None))

    def update_stats(self):
        "Accumulate the current gradients into every statistic and store the new values."
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        # Global stats read their hyper-parameters from the first param group
        self._set_bufs('global', self.global_stats, self.param_groups[0])
|
|
|
|