Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add reverse option to ndarray inplace reshape (#10767)
Browse files Browse the repository at this point in the history
* add reverse option to ndarray inplace reshape

* update check
  • Loading branch information
szha authored and piiswrong committed May 3, 2018
1 parent 20bbef5 commit 66b2944
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 30 deletions.
1 change: 1 addition & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
bool reverse,
NDArrayHandle *out);
/*!
* \brief get the shape of the array
Expand Down
26 changes: 20 additions & 6 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,19 @@ def reshape(self, *shape, **kwargs):
- input shape = (2,3,4), shape = (-4,1,2,-2), output shape =(1,2,3,4)
- input shape = (2,3,4), shape = (2,-4,-1,3,-2), output shape = (2,1,3,4)
- If the argument `reverse` is set to 1, then the special values are inferred from right
to left.
Example::
- without reverse=1, for input shape = (10,5,4), shape = (-1,0), output shape would be
(40,5).
- with reverse=1, output shape will be (50,4).
reverse : bool, default False
If true then the special values are inferred from right to left. Only supported as
keyword argument.
Returns
-------
Expand Down Expand Up @@ -1029,18 +1042,19 @@ def reshape(self, *shape, **kwargs):
elif not shape:
shape = kwargs.get('shape')
assert shape, "Shape must be provided."
if len(kwargs) != 1:
raise TypeError("Only 'shape' is supported as keyword argument. Got: {}."
.format(', '.join(kwargs.keys())))
else:
assert not kwargs,\
"Specifying both positional and keyword arguments is not allowed in reshape."
if not all(k in ['shape', 'reverse'] for k in kwargs):
raise TypeError(
"Got unknown keywords in reshape: {}. " \
"Accepted keyword arguments are 'shape' and 'reverse'.".format(
', '.join([k for k in kwargs if k not in ['shape', 'reverse']])))
reverse = kwargs.get('reverse', False)
handle = NDArrayHandle()

# Actual reshape
check_call(_LIB.MXNDArrayReshape64(self.handle,
len(shape),
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
return NDArray(handle=handle, writable=self.writable)

Expand Down
3 changes: 2 additions & 1 deletion src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,12 +431,13 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle,
MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
int ndim,
dim_t *dims,
bool reverse,
NDArrayHandle *out) {
NDArray *ptr = new NDArray();
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), false);
TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
API_END_HANDLE_ERROR(delete ptr);
Expand Down
39 changes: 16 additions & 23 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,30 +154,23 @@ def test_ndarray_negate():

@with_seed()
def test_ndarray_reshape():
tensor = mx.nd.array([[[1, 2], [3, 4]],
[[5, 6], [7, 8]]])
true_res = mx.nd.arange(8) + 1
assert same(tensor.reshape((-1, )).asnumpy(), true_res.asnumpy())
true_res = mx.nd.array([[1, 2, 3, 4],
[5, 6, 7, 8]])
assert same(tensor.reshape((2, -1)).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape((0, -1)).asnumpy(), true_res.asnumpy())
true_res = mx.nd.array([[1, 2],
[3, 4],
[5, 6],
[7, 8]])
assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(4, 2).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(-1, 2).asnumpy(), true_res.asnumpy())
true_res = mx.nd.arange(8) + 1
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
true_res = mx.nd.arange(30) + 1
assert same(tensor.reshape((-1,)).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape((2, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
assert same(tensor.reshape((0, -1)).asnumpy(), true_res.reshape(2, 15).asnumpy())
assert same(tensor.reshape((-1, 2)).asnumpy(), true_res.reshape(15, 2).asnumpy())
assert same(tensor.reshape(6, 5).asnumpy(), true_res.reshape(6, 5).asnumpy())
assert same(tensor.reshape(-1, 2).asnumpy(), true_res.reshape(15, 2).asnumpy())
assert same(tensor.reshape(-1).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(8).asnumpy(), true_res.asnumpy())

assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 4).asnumpy())
assert same(tensor.reshape(-1, 4).asnumpy(), true_res.reshape(2, 4).asnumpy())
assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(4, 2).asnumpy())
assert same(tensor.reshape(-1, 4).reshape(0, -4, 2, -1).asnumpy(), true_res.reshape(2, 2, 2).asnumpy())
assert same(tensor.reshape(30).asnumpy(), true_res.asnumpy())
assert same(tensor.reshape(0, -1).asnumpy(), true_res.reshape(2, 15).asnumpy())
assert same(tensor.reshape(-1, 6).asnumpy(), true_res.reshape(5, 6).asnumpy())
assert same(tensor.reshape(-2,).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-3, -1).asnumpy(), true_res.reshape(6, 5).asnumpy())
assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())


@with_seed()
Expand Down

0 comments on commit 66b2944

Please sign in to comment.