From a96d40f6c65cf90edbf97d20e7d26dcd0d14d273 Mon Sep 17 00:00:00 2001 From: Chris Olivier Date: Wed, 22 Nov 2017 14:35:54 -0800 Subject: [PATCH] cast scalar value in invoke to float (#8778) --- python/mxnet/optimizer.py | 4 ++-- python/mxnet/symbol/symbol.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 5eb4f05d6dca..013455614f37 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -793,9 +793,9 @@ def update(self, index, weight, grad, state): srt = op.sqrt(adjusted_add) div = _internal._scatter_elemwise_div(grad, srt) retained_weight = sparse.retain(weight, grad.indices) - to_add = sparse.elemwise_add(div, _internal._mul_scalar(retained_weight, wd)) + to_add = sparse.elemwise_add(div, _internal._mul_scalar(retained_weight, float(wd))) assert len(to_add.indices) == grad_indices_count - weight[:] = sparse.elemwise_add(weight, _internal._mul_scalar(to_add, -lr)) + weight[:] = sparse.elemwise_add(weight, _internal._mul_scalar(to_add, float(-lr))) state[:] = history assert state.stype == save_history_stype assert len(history_indices) == grad_indices_count diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index e2cf0ecb68f3..ce7776d94844 100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -2759,7 +2759,7 @@ def full(shape, val, dtype=None, **kwargs): """ if dtype is None: dtype = _numpy.float32 - return _internal._full(shape=shape, dtype=dtype, value=val, **kwargs) + return _internal._full(shape=shape, dtype=dtype, value=float(val), **kwargs) # pylint: disable=redefined-outer-name def arange(start, stop=None, step=1.0, repeat=1, name=None, dtype=None):