Skip to content

Commit

Permalink
Fix ndarray assignment issue with basic indexing (apache#10022)
Browse files Browse the repository at this point in the history
* Fix ndarray assignment issue with basic index

* Uncomment useful code
  • Loading branch information
reminisce authored and szha committed Mar 8, 2018
1 parent 15a99da commit 4ebb147
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,8 @@ def _set_nd_basic_indexing(self, key, value):
# may need to broadcast first
if isinstance(value, NDArray):
if value.handle is not self.handle:
if value.shape != shape:
value = value.broadcast_to(shape)
value.copyto(self)
elif isinstance(value, numeric_types):
_internal._full(shape=shape, ctx=self.context,
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,8 @@ def test_setitem(np_array, index, is_scalar):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
if np_value is not None:
np_array[np_index] = np_value
elif isinstance(mx_value, mx.nd.NDArray):
np_array[np_index] = mx_value.asnumpy()
else:
np_array[np_index] = mx_value
mx_array[mx_index] = mx_value
Expand Down Expand Up @@ -1024,6 +1026,9 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None)
# test value is an numeric_type
assert_same(np_array, np_index, mx_array, index, np.random.randint(low=-10000, high=0))
if len(indexed_array_shape) > 1:
# test NDArray with broadcast
assert_same(np_array, np_index, mx_array, index,
mx.nd.random.uniform(low=-10000, high=0, shape=(indexed_array_shape[-1],)))
# test numpy array with broadcast
assert_same(np_array, np_index, mx_array, index,
np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)))
Expand Down

0 comments on commit 4ebb147

Please sign in to comment.