Commit 78ae62d
Parent(s): 8b2a131

Upload 5 files

- sync_batchnorm/__init__.py +12 -0
- sync_batchnorm/batchnorm.py +315 -0
- sync_batchnorm/comm.py +137 -0
- sync_batchnorm/replicate.py +94 -0
- sync_batchnorm/unittest.py +29 -0
sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File : __init__.py
+# Author : Jiayuan Mao
+# Email : [email protected]
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
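
As a quick orientation (not part of this commit): a minimal sketch of how the exported pieces are typically combined, assuming the `sync_batchnorm` directory is importable as a package, at least two CUDA devices are available, and a PyTorch version contemporary with this code. The toy network and device ids are illustrative only.

import torch
import torch.nn as nn
from sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback

# A toy network; the layer sizes here are arbitrary.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    SynchronizedBatchNorm2d(16),
    nn.ReLU(),
)

if torch.cuda.device_count() >= 2:
    net = DataParallelWithCallback(net.cuda(), device_ids=[0, 1])
    net.train()                            # statistics are only synchronized in training mode
    x = torch.randn(8, 3, 32, 32).cuda()
    y = net(x)                             # batch statistics are reduced across both GPUs
    print(y.shape)

With a single GPU or on CPU the module falls back to the standard `F.batch_norm` path (see `forward` in batchnorm.py below), so it behaves like the built-in layer.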
sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File : batchnorm.py
+# Author : Jiayuan Mao
+# Email : [email protected]
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+    """sum over the first and last dimension"""
+    return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+    """add new dimensions at the front and the tail"""
+    return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+        self._sync_master = SyncMaster(self._data_parallel_master)
+
+        self._is_parallel = False
+        self._parallel_id = None
+        self._slave_pipe = None
+
+    def forward(self, input):
+        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+        if not (self._is_parallel and self.training):
+            return F.batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training, self.momentum, self.eps)
+
+        # Resize the input to (B, C, -1).
+        input_shape = input.size()
+        input = input.view(input.size(0), self.num_features, -1)
+
+        # Compute the sum and square-sum.
+        sum_size = input.size(0) * input.size(2)
+        input_sum = _sum_ft(input)
+        input_ssum = _sum_ft(input ** 2)
+
+        # Reduce-and-broadcast the statistics.
+        if self._parallel_id == 0:
+            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+        else:
+            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+        # Compute the output.
+        if self.affine:
+            # MJY:: Fuse the multiplication for speed.
+            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+        else:
+            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+        # Reshape it.
+        return output.view(input_shape)
+
+    def __data_parallel_replicate__(self, ctx, copy_id):
+        self._is_parallel = True
+        self._parallel_id = copy_id
+
+        # parallel_id == 0 means master device.
+        if self._parallel_id == 0:
+            ctx.sync_master = self._sync_master
+        else:
+            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+    def _data_parallel_master(self, intermediates):
+        """Reduce the sum and square-sum, compute the statistics, and broadcast them."""
+
+        # Always using the same "device order" makes the ReduceAdd operation faster.
+        # Thanks to:: Tete Xiao (http://tetexiao.com/)
+        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+        to_reduce = [i[1][:2] for i in intermediates]
+        to_reduce = [j for i in to_reduce for j in i]  # flatten
+        target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+        sum_size = sum([i[1].sum_size for i in intermediates])
+        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+        outputs = []
+        for i, rec in enumerate(intermediates):
+            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+        return outputs
+
+    def _compute_mean_std(self, sum_, ssum, size):
+        """Compute the mean and standard-deviation with sum and square-sum. This method
+        also maintains the moving average on the master device."""
+        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+        mean = sum_ / size
+        sumvar = ssum - sum_ * mean
+        unbias_var = sumvar / (size - 1)
+        bias_var = sumvar / size
+
+        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+        return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+    mini-batch.
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics are computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.
+
+    Args:
+        num_features: num_features from an expected input of size
+            `batch_size x num_features [x width]`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C)` or :math:`(N, C, L)`
+        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm1d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm1d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+    of 3d inputs
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics are computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.
+
+    Args:
+        num_features: num_features from an expected input of
+            size batch_size x num_features x height x width
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, H, W)`
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm2d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm2d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+    of 4d inputs
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalizes the tensor on each device using
+    the statistics only on that device, which accelerates the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics are computed
+    over all training samples distributed on multiple devices.
+
+    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric BatchNorm
+    or Spatio-temporal BatchNorm.
+
+    Args:
+        num_features: num_features from an expected input of
+            size batch_size x num_features x depth x height x width
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
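
The statistics in `_compute_mean_std` are recovered from per-device sums and squared sums rather than from the raw activations. A small CPU-only sketch (the shapes are hypothetical, not from the commit) checking that this sum/square-sum formulation matches the naive per-channel mean and biased variance:

import torch

# Shapes mirror the (B, C, -1) view used in forward().
x = torch.randn(8, 16, 32)
size = x.size(0) * x.size(2)

# Statistics from sums, as in _compute_mean_std.
s = x.sum(dim=0).sum(dim=-1)           # _sum_ft(x)
ss = (x ** 2).sum(dim=0).sum(dim=-1)   # _sum_ft(x ** 2)
mean = s / size
bias_var = (ss - s * mean) / size      # sum((x - mean)^2) / n
inv_std = bias_var.clamp(1e-5) ** -0.5

# Direct per-channel reference, computed the naive way.
ref_mean = x.mean(dim=(0, 2))
ref_var = ((x - ref_mean.view(1, -1, 1)) ** 2).mean(dim=(0, 2))

print(torch.allclose(mean, ref_mean, atol=1e-5))                        # True
print(torch.allclose(inv_std, ref_var.clamp(1e-5) ** -0.5, atol=1e-4))  # True

Working from sums is what makes the cross-device reduction cheap: each replica ships only two vectors of length C plus a count, and the master reconstructs the global mean and variance from their totals.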
sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File : comm.py
+# Author : Jiayuan Mao
+# Email : [email protected]
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as a one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result hasn\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During replication, as data parallel will trigger a callback on each module, all slave devices should
+    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
+    - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected
+    and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine the message to be
+    passed back to each slave device.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register a slave device.
+
+        Args:
+            identifier: an identifier, usually the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages are first collected from each device (including the master device), and then
+        a callback is invoked to compute the message to be sent back to each device
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master wants to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belong to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
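
The master/slave handshake above can be exercised without any GPUs. A minimal sketch, assuming the package is importable as `sync_batchnorm`; the callback and the integer messages are made up for illustration and stand in for the batch-norm statistics:

import threading
from sync_batchnorm.comm import SyncMaster  # assumes the package sits on the Python path

def master_callback(intermediates):
    # intermediates is a list of (copy_id, message); entry 0 is always the master's own message.
    total = sum(msg for _, msg in intermediates)
    return [(copy_id, total) for copy_id, _ in intermediates]

sync_master = SyncMaster(master_callback)

# Register the slave pipes before the master starts its "forward pass".
pipes = {i: sync_master.register_slave(i) for i in (1, 2)}
replies = {}

def slave(copy_id, value):
    replies[copy_id] = pipes[copy_id].run_slave(value)

threads = [threading.Thread(target=slave, args=(i, i * 10)) for i in (1, 2)]
for t in threads:
    t.start()

master_reply = sync_master.run_master(1)  # the master contributes the value 1
for t in threads:
    t.join()

print(master_reply, replies)  # master_reply == 31, and both slave replies are 31

Each `run_slave` blocks until the master's callback result arrives, and `run_master` in turn waits for every slave's acknowledgement, which is what keeps all replicas in lockstep during a forward pass.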
sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File : replicate.py
+# Author : Jiayuan Mao
+# Email : [email protected]
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+    'CallbackContext',
+    'execute_replication_callbacks',
+    'DataParallelWithCallback',
+    'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all modules are isomorphic, we assign each sub-module a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+    of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+    """
+    Data Parallel with a replication callback.
+
+    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+    the original `replicate` function.
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+        # sync_bn.__data_parallel_replicate__ will be invoked.
+    """
+
+    def replicate(self, module, device_ids):
+        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have a customized `DataParallel` implementation.
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+        > patch_replication_callback(sync_bn)
+        # this is equivalent to
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+    """
+
+    assert isinstance(data_parallel, DataParallel)
+
+    old_replicate = data_parallel.replicate
+
+    @functools.wraps(old_replicate)
+    def new_replicate(module, device_ids):
+        modules = old_replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+    data_parallel.replicate = new_replicate
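
A CPU-only sketch of the callback mechanism, using a hypothetical `Recorder` module and deep copies in place of real GPU replicas: each replica's sub-module at the same position receives the same `CallbackContext`, and the master copy (copy 0) is called first.

import copy
import torch.nn as nn
from sync_batchnorm.replicate import execute_replication_callbacks  # assumes the package is on the path

class Recorder(nn.Module):
    """Toy module that records what the replication callback hands it."""
    def __init__(self):
        super(Recorder, self).__init__()
        self.copy_id = None
        self.shared_ctx = None

    def __data_parallel_replicate__(self, ctx, copy_id):
        self.copy_id = copy_id
        self.shared_ctx = ctx
        if copy_id == 0:
            ctx.master_value = 'set by the master copy'

master = nn.Sequential(Recorder())
replicas = [master, copy.deepcopy(master)]
execute_replication_callbacks(replicas)

# The second replica sees the attribute written by the master copy,
# because both received the same CallbackContext for that sub-module position.
print(replicas[1][0].copy_id)                  # 1
print(replicas[1][0].shared_ctx.master_value)  # 'set by the master copy'

This is exactly how `_SynchronizedBatchNorm.__data_parallel_replicate__` lets slave copies find the master's `SyncMaster` after `DataParallel` has replicated the module.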
sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File : unittest.py
+# Author : Jiayuan Mao
+# Email : [email protected]
+# Date : 27/01/2018
+#
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+    if isinstance(v, Variable):
+        v = v.data
+    return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+        npa, npb = as_numpy(a), as_numpy(b)
+        self.assertTrue(
+            np.allclose(npa, npb, atol=atol),
+            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+        )