add use_global_stats in nn.BatchNorm
tornadomeet authored and piiswrong committed Jan 22, 2018
1 parent ab0d1d5 commit dae6cda
Showing 1 changed file with 6 additions and 2 deletions.
python/mxnet/gluon/nn/basic_layers.py: 6 additions & 2 deletions
@@ -308,6 +308,10 @@ class BatchNorm(HybridBlock):
         When the next layer is linear (also e.g. `nn.relu`),
         this can be disabled since the scaling
         will be done by the next layer.
+    use_global_stats: bool, default False
+        If True, use global moving statistics instead of local batch-norm statistics;
+        this turns batch normalization into a fixed scale-shift operator.
+        If False, use local batch-norm statistics.
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -329,12 +333,12 @@ class BatchNorm(HybridBlock):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 beta_initializer='zeros', gamma_initializer='ones',
+                 use_global_stats=False, beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
         super(BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not scale}
+                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
         if in_channels != 0:
             self.in_channels = in_channels


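For reference, a minimal usage sketch of the new keyword (assuming an MXNet build that includes this change; everything else below is standard Gluon API). With `use_global_stats=True` the layer normalizes with its running mean/variance even inside `autograd.record()`, so it behaves as a fixed scale-shift instead of using per-batch statistics:

import mxnet as mx
from mxnet import autograd
from mxnet.gluon import nn

# BatchNorm that always uses the global (moving) statistics, even during
# training, so it reduces to a fixed scale-shift of its input.
bn = nn.BatchNorm(use_global_stats=True)
bn.initialize()

x = mx.nd.random.uniform(shape=(2, 3, 8, 8))
with autograd.record():
    y = bn(x)  # normalized with running_mean/running_var, not batch statistics
print(y.shape)  # (2, 3, 8, 8)

As the second hunk shows, the flag is simply added to `self._kwargs`, which the layer forwards to the backend `BatchNorm` operator; that operator already accepts `use_global_stats`, so no further changes are needed.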