Skip to content

Commit

Permalink
make array.reshape compatible with numpy (apache#9790)
Browse files Browse the repository at this point in the history
* make array.reshape compatible with numpy

* update

* add exception when both *args and **kwargs are specified

* update
  • Loading branch information
szha authored and piiswrong committed Feb 19, 2018
1 parent d048615 commit 9348a3a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
20 changes: 18 additions & 2 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,12 +926,12 @@ def _at(self, idx):
self.handle, mx_uint(idx), ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

def reshape(self, shape):
def reshape(self, *shape, **kwargs):
"""Returns a **view** of this array with a new shape without altering any data.
Parameters
----------
shape : tuple of int
shape : tuple of int, or n ints
The new shape should not change the array size, namely
``np.prod(new_shape)`` should be equal to ``np.prod(self.shape)``.
Expand Down Expand Up @@ -960,6 +960,11 @@ def reshape(self, shape):
[ 4., 5.]], dtype=float32)
>>> y = x.reshape((3,-1))
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
>>> y = x.reshape(3,2)
>>> y.asnumpy()
array([[ 0., 1.],
[ 2., 3.],
[ 4., 5.]], dtype=float32)
Expand All @@ -968,6 +973,17 @@ def reshape(self, shape):
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
"""
if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
shape = shape[0]
elif not shape:
shape = kwargs.get('shape')
assert shape, "Shape must be provided."
if len(kwargs) != 1:
raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
.format(', '.join(kwargs.keys())))
else:
assert not kwargs,\
"Specifying both positional and keyword arguments is not allowed in reshape."
handle = NDArrayHandle()

# Actual reshape
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _at(self, idx):
def _slice(self, start, stop):
raise NotSupportedForSparseNDArray(self._slice, None, start, stop)

def reshape(self, shape):
def reshape(self, *shape, **kwargs):
raise NotSupportedForSparseNDArray(self.reshape, None, shape)

@property
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def test_ndarray_reshape():
[5, 6],
[7, 8]])
assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
true_res = mx.nd.arange(8) + 1
assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())


@with_seed()
Expand Down

0 comments on commit 9348a3a

Please sign in to comment.