55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
@@ -222,5 +222,60 @@ def test_3d(self):
self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)


class TestBatchNormUseGlobalStats(unittest.TestCase):
def setUp(self):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
self.places.append(fluid.CUDAPlace(0))
self.init_test()

### train mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = False

def test_global_stats(self):
for p in self.places:
with fluid.dygraph.guard(p):
x = paddle.randn([2, 6, 6, 4])
net1 = paddle.fluid.dygraph.BatchNorm(
6,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.0)),
use_global_stats=self.use_global_stats,
trainable_statistics=self.trainable_statistics)
net2 = paddle.nn.BatchNorm2D(
6, use_global_stats=self.use_global_stats)
net2.weight = net1.weight
net2.bias = net1.bias
                if self.trainable_statistics:
net1.training = False
net2.training = False
y1 = net1(x)
y2 = net2(x)
self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)


class TestBatchNormUseGlobalStatsCase1(TestBatchNormUseGlobalStats):
### test mode
def init_test(self):
self.use_global_stats = False
self.trainable_statistics = True


class TestBatchNormUseGlobalStatsCase2(TestBatchNormUseGlobalStats):
### train mode
def init_test(self):
self.use_global_stats = False
self.trainable_statistics = False


class TestBatchNormUseGlobalStatsCase3(TestBatchNormUseGlobalStats):
### test mode
def init_test(self):
self.use_global_stats = True
self.trainable_statistics = True


if __name__ == '__main__':
unittest.main()
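For quick reference, the 2x2 grid of flag combinations the four classes above exercise, read directly from their init_test overrides (note that only use_global_stats is forwarded to paddle.nn.BatchNorm2D, while the legacy fluid layer receives both flags):

# (use_global_stats, trainable_statistics) per test class, as set in init_test:
FLAG_GRID = {
    "TestBatchNormUseGlobalStats": (True, False),       # train mode
    "TestBatchNormUseGlobalStatsCase1": (False, True),  # test mode
    "TestBatchNormUseGlobalStatsCase2": (False, False),  # train mode
    "TestBatchNormUseGlobalStatsCase3": (True, True),    # test mode
}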
14 changes: 11 additions & 3 deletions python/paddle/nn/functional/norm.py
@@ -123,6 +123,7 @@ def batch_norm(x,
momentum=0.9,
epsilon=1e-05,
data_format="NCHW",
use_global_stats=None,
name=None):
"""
Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
Expand All @@ -139,6 +140,7 @@ def batch_norm(x,
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
training(bool, optional): True means train mode, which computes statistics from the current mini-batch and tracks global mean and variance during training. False means inference mode, which uses the global mean and variance computed during training. Default: False.
data_format(str, optional): Specify the input data format, may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch; if set to True, use the global statistics; if set to None, use global statistics in the test phase and mini-batch statistics in the training phase. Default: None.
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

Returns:
@@ -167,8 +169,6 @@ def batch_norm(x,

assert len(x.shape) >= 2, "input dim must be larger than 1"

# we use not training means use_global_status, more details see nn._BatchNormBase
use_global_stats = not training
    # input and out must share the memory
mean_out = running_mean
variance_out = running_var
@@ -181,11 +181,18 @@

data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'

    if use_global_stats is None:
use_global_stats = not training
trainable_statistics = False
else:
trainable_statistics = not use_global_stats

if in_dygraph_mode():
        # dygraph mode expects attrs as a flat tuple
attrs = ("momentum", momentum, "epsilon", epsilon, "data_layout",
data_format, "use_mkldnn", False, "fuse_with_relu", False,
"use_global_stats", use_global_stats)
"use_global_stats", use_global_stats, "trainable_statistics",
trainable_statistics)
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
x, weight, bias, running_mean, running_var, mean_out, variance_out,
*attrs)
@@ -204,6 +211,7 @@
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats,
"trainable_statistics": trainable_statistics,
}

inputs = {
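For clarity, a minimal standalone sketch of the flag-resolution logic added above (the helper name is illustrative, not part of the API): when use_global_stats is left as None the old behavior is preserved (use_global_stats = not training), and trainable_statistics is switched on only when the caller explicitly asks for mini-batch statistics.

def _resolve_bn_flags(training, use_global_stats):
    # Mirrors the branch added to paddle.nn.functional.batch_norm above.
    if use_global_stats is None:
        # Old behavior: inference mode implies global statistics.
        return (not training), False
    # Explicit setting: trainable_statistics is the complement.
    return use_global_stats, (not use_global_stats)

# e.g. _resolve_bn_flags(training=True, use_global_stats=None) -> (False, False)
#      _resolve_bn_flags(training=True, use_global_stats=True) -> (True, False)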
8 changes: 7 additions & 1 deletion python/paddle/nn/layer/norm.py
@@ -550,11 +550,13 @@ def __init__(self,
weight_attr=None,
bias_attr=None,
data_format='NCHW',
use_global_stats=None,
name=None):
super(_BatchNormBase, self).__init__()
self._num_features = num_features
self._weight_attr = weight_attr
self._bias_attr = bias_attr
self._use_global_stats = use_global_stats

if get_default_dtype() == 'float16':
set_default_dtype('float32')
@@ -642,7 +644,8 @@ def forward(self, input):
training=self.training,
momentum=self._momentum,
epsilon=self._epsilon,
data_format=self._data_format)
data_format=self._data_format,
use_global_stats=self._use_global_stats)


class BatchNorm1D(_BatchNormBase):
@@ -694,6 +697,7 @@ class BatchNorm1D(_BatchNormBase):
will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
If the Initializer of the bias_attr is not set, the bias is initialized to zero. Default: None.
data_format(str, optional): Specify the input data format, may be "NC", "NCL" or "NLC". Default: "NCL".
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch; if set to True, use the global statistics; if set to None, use global statistics in the test phase and mini-batch statistics in the training phase. Default: None.
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

Shape:
@@ -784,6 +788,7 @@ class BatchNorm2D(_BatchNormBase):
will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
If the Initializer of the bias_attr is not set, the bias is initialized to zero. Default: None.
data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: "NCHW".
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch; if set to True, use the global statistics; if set to None, use global statistics in the test phase and mini-batch statistics in the training phase. Default: None.
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

Shape:
@@ -872,6 +877,7 @@ class BatchNorm3D(_BatchNormBase):
will create ParamAttr as bias_attr. If it is set to False, the bias is not learnable.
If the Initializer of the bias_attr is not set, the bias is initialized to zero. Default: None.
data_format(str, optional): Specify the input data format, the data format can be "NCDHW" or "NDHWC". Default: "NCDHW".
use_global_stats(bool|None, optional): Whether to use global mean and variance. If set to False, use the statistics of one mini-batch; if set to True, use the global statistics; if set to None, use global statistics in the test phase and mini-batch statistics in the training phase. Default: None.
name(str, optional): Name for the BatchNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

Shape:
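Putting the new argument together, a hedged usage sketch (assuming a Paddle 2.0 dygraph environment; shapes and values are illustrative): with use_global_stats=True the layer normalizes with its tracked running statistics even in train mode, so the same input maps to the same output across calls.

import numpy as np
import paddle

bn = paddle.nn.BatchNorm2D(6, use_global_stats=True)
bn.train()  # train mode, but normalization uses the running mean/variance
x = paddle.randn([2, 6, 6, 4])  # NCHW layout, 6 channels
y1 = bn(x)
y2 = bn(x)
# Outputs match because the statistics are frozen to the global values.
np.testing.assert_allclose(y1.numpy(), y2.numpy())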