diff --git a/python/mxnet/initializer.py b/python/mxnet/initializer.py index 990dbbf1..e5eaea99 100644 --- a/python/mxnet/initializer.py +++ b/python/mxnet/initializer.py @@ -40,7 +40,7 @@ def __call__(self, name, arr): elif name.endswith("moving_mean"): self._init_zero(name, arr) elif name.endswith("moving_var"): - self._init_zero(name, arr) + self._init_one(name, arr) elif name.endswith("moving_inv_var"): self._init_zero(name, arr) elif name.endswith("moving_avg"): @@ -62,6 +62,9 @@ def _init_bilinear(self, _, arr): def _init_zero(self, _, arr): arr[:] = 0.0 + def _init_one(self, _, arr): + arr[:] = 1.0 + def _init_bias(self, _, arr): arr[:] = 0.0 diff --git a/src/operator/cudnn_batch_norm-inl.h b/src/operator/cudnn_batch_norm-inl.h index cc94b363..fc3bd86d 100644 --- a/src/operator/cudnn_batch_norm-inl.h +++ b/src/operator/cudnn_batch_norm-inl.h @@ -115,7 +115,7 @@ class CuDNNBatchNormOp : public Operator { mean_desc_, gamma.dptr_, beta.dptr_, - param_.momentum, + 1 - param_.momentum, moving_mean.dptr_, moving_inv_var.dptr_, param_.eps,