diff --git a/src/operator/batch_norm.cu b/src/operator/batch_norm.cu
index 64f7d9373823..9a8b576a16ee 100644
--- a/src/operator/batch_norm.cu
+++ b/src/operator/batch_norm.cu
@@ -283,7 +283,7 @@ __global__ void BatchNormalizationUpdateOutputKernel(
   }
 
   // Write normalized and update the output
-  const AccReal gamma = weight.numElements() > 0
+  const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                         ? ScalarConvert<DType, AccReal>::to(weight[plane])
                         : ScalarConvert<int, AccReal>::to(1);
   const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
@@ -332,7 +332,7 @@ static __global__ void BatchNormalizationBackwardKernel(
     invstd = VARIANCE_TO_INVSTD(tensors.runningVar[plane], eps);
   }
 
-  const AccReal weightVal = tensors.weight.numElements() > 0 ?
+  const AccReal weightVal = ((flags & FIX_GAMMA_FLAG) == 0 && tensors.weight.numElements() > 0) ?
                             ScalarConvert<DType, AccReal>::to(tensors.weight[plane]) : AccReal(1);
   const AccReal norm = AccReal(1) / N;
 
diff --git a/src/operator/random/sample_op.cu b/src/operator/random/sample_op.cu
index 0d4b2e5a8270..7bdb9faf334e 100644
--- a/src/operator/random/sample_op.cu
+++ b/src/operator/random/sample_op.cu
@@ -28,21 +28,20 @@ namespace op {
 
 // GPU versions of uniform and normal distribution.
 template<>
-void SampleUniform_<gpu>(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
+void SampleUniformDnsImpl<gpu>(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const OpReqType& req,
+                               TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   typedef gpu xpu;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
   mshadow::Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
 
-  if (outputs[0].type_flag_ != mshadow::kFloat32) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  if (output->type_flag_ != mshadow::kFloat32) {
+    MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
       // Not float32: use workspace and copy to output
-      mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
       mshadow::Tensor<xpu, 1, float> workspace =
         ctx.requested[1].get_space_typed<xpu, 1, float>(mshadow::Shape1(out.shape_.Size()), s);
@@ -51,27 +50,36 @@ void SampleUniform_(const nnvm::NodeAttrs& attrs,
     });
   } else {
     // float32: write directly into output
-    mshadow::Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+    mshadow::Tensor<xpu, 2, float> out = output->FlatTo2D<xpu, float>(s);
    prnd->SampleUniform(&out, param.low, param.high);
   }
 }
 
 template<>
-void SampleNormal_<gpu>(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const std::vector<TBlob>& inputs,
-                        const std::vector<OpReqType>& req,
-                        const std::vector<TBlob>& outputs) {
+void SampleUniform_<gpu>(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleUniformDnsImpl<gpu>(attrs, ctx, req[0], &out);
+}
+
+template<>
+void SampleNormalDnsImpl<gpu>(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const OpReqType& req,
+                              TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   typedef gpu xpu;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleNormalParam& param = nnvm::get<SampleNormalParam>(attrs.parsed);
   mshadow::Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
 
-  if (outputs[0].type_flag_ != mshadow::kFloat32) {
-    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  if (output->type_flag_ != mshadow::kFloat32) {
+    MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
       // Not float32: use workspace and copy to output
-      mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+      mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
       mshadow::Tensor<xpu, 1, float> workspace =
         ctx.requested[1].get_space_typed<xpu, 1, float>(mshadow::Shape1(out.shape_.Size()), s);
@@ -80,16 +88,28 @@ void SampleNormal_(const nnvm::NodeAttrs& attrs,
     });
   } else {
     // float32: write directly into output
-    mshadow::Tensor<xpu, 2, float> out = outputs[0].FlatTo2D<xpu, float>(s);
+    mshadow::Tensor<xpu, 2, float> out = output->FlatTo2D<xpu, float>(s);
     prnd->SampleGaussian(&out, param.loc, param.scale);
   }
 }
 
+template<>
+void SampleNormal_<gpu>(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleNormalDnsImpl<gpu>(attrs, ctx, req[0], &out);
+}
+
 NNVM_REGISTER_OP(random_uniform)
-.set_attr<FCompute>("FCompute<gpu>", SampleUniform_<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SampleUniform_<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SampleUniformEx_<gpu>);
 
 NNVM_REGISTER_OP(random_normal)
-.set_attr<FCompute>("FCompute<gpu>", SampleNormal_<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", SampleNormal_<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", SampleNormalEx_<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index cb8cfc4b73f4..fe4841bc0979 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -867,7 +867,6 @@ def check_batchnorm_training(stype):
             rolling_mean = np.random.uniform(size=s)
             rolling_std = np.random.uniform(size=s)
 
-            stype = 'row_sparse'
             data = mx.symbol.Variable('data', stype=stype)
             in_location = [mx.nd.array(data_tmp).tostype(stype), mx.nd.array(gamma).tostype(stype),
                            mx.nd.array(beta).tostype(stype)]
@@ -935,7 +934,7 @@ def check_batchnorm_training(stype):
             test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
             check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-    stypes = ['row_sparse', 'csr', 'default']
+    stypes = ['row_sparse', 'default']
     for stype in stypes:
         check_batchnorm_training(stype)
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index aecec1df9b84..1849bf7107e4 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -352,6 +352,10 @@ def test_sparse_nd_output_fallback():
     assert(np.sum(out.asnumpy()) != 0)
 
 def test_sparse_nd_random():
+    """ test sparse random operator on cpu """
+    # gpu random operator doesn't use fixed seed
+    if default_context().device_type == 'gpu':
+        return
     shape = (100, 100)
     fns = [mx.nd.random_uniform, mx.nd.random_normal, mx.nd.random_gamma]
     for fn in fns:
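
For context on the new GPU guard in test_sparse_nd_random: the test relies on fixed-seed reproducibility, which the CPU sampler provides but (per the comment in the patch) the GPU random operator does not. A minimal standalone sketch of that property, illustrative only and not part of the patch, assuming an MXNet build with the ndarray random ops registered above:

import mxnet as mx
import numpy as np

shape = (100, 100)
mx.random.seed(128)
a = mx.nd.random_uniform(shape=shape).asnumpy()  # draw once under a fixed seed
mx.random.seed(128)
b = mx.nd.random_uniform(shape=shape).asnumpy()  # re-seed and draw again
# Identical draws hold on the default CPU context; the GPU sampler does not
# honor the fixed seed, which is why the test returns early on GPU.
assert np.array_equal(a, b)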