from .torch_core import *
from torch.optim import Optimizer
import types

__all__ = ['StatScope', 'Statistic', 'ConstStatistic', 'AvgStatistic', 'AvgSquare', 'GeneralOptimizer']

StatScope = Enum('StatScope', 'Global Group Layer Channel Weight')
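
# A `Statistic` can be kept at five scopes: one value for the whole model
# (`Global`), one per param group (`Group`), one per parameter tensor
# (`Layer`), one per output channel (`Channel`), or one per individual
# weight (`Weight`).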
@dataclass
class Statistic():
    name:str
    param:float=0.9 # e.g. for exp moving average
    scope:StatScope=StatScope.Weight
    init:float=0.   # starting value

    @property
    def buf(self): return f'{self.name}_buffer'

    def new_step(self):
        "Set state when computing statistics for Global or Group"
        raise NotImplementedError

    def accumulate(self, val):
        "Add `val` to statistic"
        raise NotImplementedError

    def update(self, state, param, val=None, step=None):
        "Update state with accumulated statistic, or `val` (if `Weight` or `Layer` scope)"
        raise NotImplementedError
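
# A statistic that never changes: `update` always returns the hyper-parameter
# `param`, and nothing is buffered or accumulated.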
class ConstStatistic(Statistic):
    @property
    def buf(self): return None
    def new_step(self): pass
    def accumulate(self, val): pass  # take `val` to match the base-class signature
    def update(self, state, param, val=None, step=None): return param
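
# A plain step counter. `__post_init__` stashes `name` as the buffer key and
# then clears it, so the counter contributes no entry to the optimizer's
# hyper-parameter defaults.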
@dataclass
class CounterStat(Statistic):
    def __post_init__(self): self.init,self._buf,self.name = 0,self.name,None
    @property
    def buf(self): return self._buf
    def new_step(self): pass
    def accumulate(self, val): pass
    def update(self, state, param, val=None, step=None): return state + 1
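
# `AvgStatistic` keeps a running average of what it accumulates (gradients, by
# default). With `decay=True` it is an exponential moving average,
# `state*param + val*(1-param)`; with `decay=False` it is a momentum-style sum,
# `state*param + val`. With `debias=True` the reported value is divided by
# `1 - param**step` to correct the bias from zero initialization, as in Adam.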
@dataclass
class AvgStatistic(Statistic):
    decay:bool=False
    debias:bool=False

    def new_step(self): self.val,self.count = 0.,0

    def accumulate(self, val):
        self.count += 1
        self.val += self._get_val1(val)

    def _get_val1(self, val): return val.mean()
    def _get_val2(self, state, val, param): return state.add_(1-param, val) if self.decay else state.add_(val)
    def _get_val3(self, state, val, param):
        v = val.view(val.size(0), -1).mean(1)
        return state.add_(1-param, v) if self.decay else state.add_(v)

    def update(self, state, param, val=None, step=None):
        if self.scope == StatScope.Weight:
            # `state` is a tensor
            res = self._get_val2(state.mul_(param), val, param)
        elif self.scope == StatScope.Channel:
            # `state` is a tensor of size n_channels
            res = self._get_val3(state.mul_(param), val, param)
        # For everything else, `state` is a scalar
        elif self.scope == StatScope.Layer: res = state*param + self._get_val1(val) * (1-param if self.decay else 1.)
        elif self.count != 0:              res = state*param + self.val/self.count * (1-param if self.decay else 1.)
        else: return state
        if self.debias and step is not None: res /= (1 - param ** step)
        return res
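
# The same running average applied to squared values (second moments, as in
# Adam or RMSprop); note that `decay` defaults to True here.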
class AvgSquare(AvgStatistic):
    def __init__(self, name:str, param:float=0.9, scope=StatScope.Weight, init:float=0., decay:bool=True, debias:bool=False):
        super().__init__(name, param=param, scope=scope, init=init, decay=decay, debias=debias)

    def _get_val1(self, val): return torch.norm(val).pow(2)/val.numel()
    def _get_val2(self, state, val, param):
        return state.addcmul_(1-param, val, val) if self.decay else state.addcmul_(val, val)
    def _get_val3(self, state, val, param):
        v = val.view(val.size(0), -1).mean(1)
        return state.addcmul_(1-param, v, v) if self.decay else state.addcmul_(v, v)
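
# `GeneralOptimizer` separates "what is tracked" from "how weights move": the
# `stats` describe buffers maintained at each scope, while `on_step` consumes
# those buffers to update a parameter. A custom `on_step(self, p, group,
# group_idx)` passed to the constructor replaces the default SGD update.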
class GeneralOptimizer(Optimizer):
    def __init__(self, params, stats=None, on_step:Callable=None):
        defaults = {s.name:s.param for s in listify(stats) if s.name is not None}
        super().__init__(params, defaults)
        self.global_stats,self.group_stats,self.layer_stats,self.channel_stats,self.weight_stats = self._split_stats(stats)
        self.init_stats()
        if on_step is not None: self.on_step = types.MethodType(on_step, self)
    def step(self, closure=None):
        loss = None
        if closure is not None: loss = closure()  # honor the standard `Optimizer.step` contract
        self.update_stats()
        for i,pg in enumerate(self.param_groups):
            for p in pg['params']:
                if p.grad is not None: self.on_step(p, pg, i)
        return loss

    def on_step(self, p, group, group_idx): p.data.add_(-group['lr'], p.grad.data)
    def _split_stats(self, stats):
        splits = [[stat for stat in listify(stats) if stat.scope==scope] for scope in StatScope]
        for i,(split,scope) in enumerate(zip([splits[0], splits[1], splits[2]+splits[3]+splits[4]], StatScope)):
            if np.any([getattr(stat, 'debias', False) for stat in split]):
                # A debiased statistic needs a step counter stored next to its buffers.
                # Insert into `splits[i]` rather than `split`: for i==2 `split` is a
                # temporary concatenation, so inserting into it would be silently discarded.
                splits[i].insert(0, CounterStat('step', scope=scope))
        return splits
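
    # Buffers live in `self.state`: under 'global' for global statistics,
    # f'group{i}' per param group, and the parameter itself for layer, channel
    # and weight statistics. Channel/weight buffers are tensors shaped like
    # their target; everything else starts as the statistic's scalar `init`.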
    def _init_stats(self, stats, data=None):
        return {stat.buf: stat.init if data is None
                else torch.zeros_like(data) + stat.init for stat in stats if stat.buf is not None}

    def init_stats(self):
        self.state['global'] = self._init_stats(self.global_stats)
        for i,pg in enumerate(self.param_groups):
            self.state[f'group{i}'] = self._init_stats(self.group_stats)
            for p in pg['params']:
                self.state[p] = self._init_stats(self.layer_stats)
                self.state[p].update(self._init_stats(self.channel_stats, p.data.view(p.data.size(0), -1).mean(1)))
                self.state[p].update(self._init_stats(self.weight_stats, p.data))
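
    # Apply each statistic's `update` to its buffer under `self.state[p]`,
    # where `p` is a parameter, 'global', or f'group{i}'.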
    def _set_bufs(self, p, stats, pg, val=None):
        d = self.state[p]
        for stat in stats:
            # `pg.get` rather than `pg[...]`: a `CounterStat` has `name=None`, which
            # never appears in the param-group hyper-parameters.
            if stat.buf is not None: d[stat.buf] = stat.update(d[stat.buf], pg.get(stat.name), val=val, step=d.get('step', None))
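
    # One step of bookkeeping: global and group statistics accumulate across all
    # relevant gradients before committing once per scope, while layer, channel
    # and weight statistics write straight into their per-parameter buffers.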
    def update_stats(self):
        for stat in self.global_stats: stat.new_step()
        for i,pg in enumerate(self.param_groups):
            for stat in self.group_stats: stat.new_step()
            for p in pg['params']:
                if p.grad is not None:
                    for stat in self.global_stats + self.group_stats: stat.accumulate(p.grad.data)
                    self._set_bufs(p, self.layer_stats+self.channel_stats+self.weight_stats, pg, p.grad.data)
            self._set_bufs(f'group{i}', self.group_stats, pg)
        self._set_bufs('global', self.global_stats, self.param_groups[0])
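
# Example: a minimal RMSprop-style optimizer built from these pieces. This is
# an illustrative sketch, not part of the module: the statistic name `sqr_avg`,
# the toy model, and the epsilon are all choices made for the demo. The
# old-style `addcdiv_(value, t1, t2)` call matches the signatures used above.
#
#   import torch, torch.nn as nn
#
#   def rms_step(self, p, group, group_idx):
#       avg = self.state[p]['sqr_avg_buffer']      # EMA of squared gradients
#       p.data.addcdiv_(-group['lr'], p.grad.data, avg.sqrt().add_(1e-8))
#
#   model = nn.Linear(10, 2)
#   opt = GeneralOptimizer([{'params': model.parameters(), 'lr': 1e-2}],
#                          stats=AvgSquare('sqr_avg', param=0.99),
#                          on_step=rms_step)
#   loss = nn.functional.mse_loss(model(torch.randn(16, 10)), torch.randn(16, 2))
#   loss.backward()
#   opt.step()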