Add dstack that passes CPU test
Register dstack on GPU

Minor comment fix

Minor syntax fix

Syntax fix according to comments

Header fix
Mike Mao committed Aug 14, 2019
1 parent 24a5cf0 commit dcdf73b
Showing 7 changed files with 294 additions and 3 deletions.
18 changes: 17 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -27,7 +27,7 @@
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange']
'linspace', 'expand_dims', 'tile', 'arange', 'dstack']


@set_module('mxnet.ndarray.numpy')
@@ -200,6 +200,22 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
#pylint: enable= too-many-arguments, no-member, protected-access


@set_module('mxnet.ndarray.numpy')
def dstack(arrays):
    """Stack tensors in sequence depth wise (along the third axis).

    This is equivalent to concatenation along the third axis after 0-D, 1-D, and
    2-D tensors have been reshaped to 3-D: a tensor of shape ``(N,)`` is treated
    as ``(1, N, 1)`` and a tensor of shape ``(M, N)`` as ``(M, N, 1)``.

    Parameters
    ----------
    arrays : sequence of array_like
        The tensors must have the same shape along all but the third axis;
        1-D or 2-D tensors must have the same shape.

    Returns
    -------
    stacked : ndarray
        The array formed by stacking the given arrays along the third axis.
    """
    return _npi.dstack(*arrays)
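For context, the behaviour described in this docstring mirrors NumPy's own `dstack`; a minimal sketch of the expected shape handling, written against plain NumPy (illustrative only, not part of the diff):

import numpy as np

# 1-D inputs of shape (3,) are treated as (1, 3, 1) and stacked along axis 2.
a = np.array([1, 2, 3])
b = np.array([4, 5, 6])
assert np.dstack((a, b)).shape == (1, 3, 2)

# 2-D inputs of shape (2, 3) are treated as (2, 3, 1).
c = np.ones((2, 3))
d = np.zeros((2, 3))
assert np.dstack((c, d)).shape == (2, 3, 2)

# Inputs that are already 3-D are concatenated directly along axis 2.
e = np.ones((2, 3, 4))
f = np.ones((2, 3, 5))
assert np.dstack((e, f)).shape == (2, 3, 9)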


@set_module('mxnet.ndarray.numpy')
def add(x1, x2, out=None):
"""Add arguments element-wise.
19 changes: 18 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -44,7 +44,7 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange']
'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'dstack']


# This function is copied from ndarray.py since pylint
@@ -1606,6 +1606,7 @@ def tensordot(a, b, axes=2):
return _mx_nd_np.tensordot(a, b, axes)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
@@ -1819,3 +1820,19 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None):
than `stop`.
"""
return _mx_nd_np.arange(start, stop, step, dtype, ctx)


@set_module('mxnet.numpy')
def dstack(arrays):
    """Stack tensors in sequence depth wise (along the third axis).

    This is equivalent to concatenation along the third axis after 0-D, 1-D, and
    2-D tensors have been reshaped to 3-D: a tensor of shape ``(N,)`` is treated
    as ``(1, N, 1)`` and a tensor of shape ``(M, N)`` as ``(M, N, 1)``.

    Parameters
    ----------
    arrays : sequence of array_like
        The tensors must have the same shape along all but the third axis;
        1-D or 2-D tensors must have the same shape.

    Returns
    -------
    stacked : ndarray
        The array formed by stacking the given arrays along the third axis.
    """
    return _npi.dstack(*arrays)
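A minimal imperative usage sketch against the `mxnet.numpy` front end this file defines; the import style and shapes are assumptions for illustration, and NumPy semantics are assumed to be enabled the way the unit test below does with `@use_np`:

import mxnet as mx
from mxnet import numpy as np  # the numpy-compatible front end edited in this file

# Build two (2, 3) np ndarrays the same way the unit test below does.
a = mx.nd.random.uniform(shape=(2, 3)).as_np_ndarray()
b = mx.nd.random.uniform(shape=(2, 3)).as_np_ndarray()

out = np.dstack((a, b))
print(out.shape)  # expected (2, 3, 2): each (2, 3) input is treated as (2, 3, 1)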
17 changes: 16 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange']
'linspace', 'expand_dims', 'tile', 'arange', 'dstack']


def _num_outputs(sym):
@@ -1063,6 +1063,7 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
@@ -1134,6 +1135,20 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis
else:
return _npi.linspace(start=start, stop=stop, num=num, endpoint=endpoint, ctx=ctx, dtype=dtype)

@set_module('mxnet.symbol.numpy')
def dstack(arrays):
    """Stack tensors in sequence depth wise (along the third axis).

    This is equivalent to concatenation along the third axis after 0-D, 1-D, and
    2-D tensors have been reshaped to 3-D: a tensor of shape ``(N,)`` is treated
    as ``(1, N, 1)`` and a tensor of shape ``(M, N)`` as ``(M, N, 1)``.

    Parameters
    ----------
    arrays : sequence of array_like
        The tensors must have the same shape along all but the third axis;
        1-D or 2-D tensors must have the same shape.

    Returns
    -------
    stacked : _Symbol
        The stacked result.
    """
    return _npi.dstack(*arrays)

@set_module('mxnet.symbol.numpy')
def expand_dims(a, axis):
62 changes: 62 additions & 0 deletions src/operator/nn/concat-inl.h
@@ -141,6 +141,37 @@ void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
});
}

template<typename xpu>
void DStackCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  param.dim = 2;
  std::vector<TBlob> modified_inputs(inputs.size());
  // Promote 0-D, 1-D and 2-D inputs to 3-D so they can be concatenated along axis 2:
  // () -> (1, 1, 1), (N,) -> (1, N, 1), (M, N) -> (M, N, 1).
  for (int i = 0; i < param.num_args; ++i) {
    if (inputs[i].shape_.ndim() == 0) {
      modified_inputs[i] = inputs[i].reshape(TShape(3, 1));
    } else if (inputs[i].shape_.ndim() == 1) {
      TShape t = TShape(3, 1);
      t[1] = inputs[i].shape_[0];
      modified_inputs[i] = inputs[i].reshape(t);
    } else if (inputs[i].shape_.ndim() == 2) {
      TShape t = TShape(3, 1);
      t[0] = inputs[i].shape_[0];
      t[1] = inputs[i].shape_[1];
      modified_inputs[i] = inputs[i].reshape(t);
    } else {
      modified_inputs[i] = inputs[i];
    }
  }
  // Delegate to the regular Concat forward with dim fixed to 2.
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Forward(ctx, modified_inputs, req, outputs);
  });
}
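All three dstack code paths (forward, backward, and the shape inference added below) share the promotion rule implemented by the branches above. A hedged Python sketch of that rule; the helper name is illustrative, not from this commit:

def promote_to_3d(shape):
    """Pad a shape out to 3 dims the way DStackCompute's reshape branches do."""
    if len(shape) == 0:          # scalar   ()      -> (1, 1, 1)
        return (1, 1, 1)
    if len(shape) == 1:          # vector   (N,)    -> (1, N, 1)
        return (1, shape[0], 1)
    if len(shape) == 2:          # matrix   (M, N)  -> (M, N, 1)
        return (shape[0], shape[1], 1)
    return shape                 # 3-D and higher are used as-is

assert promote_to_3d(()) == (1, 1, 1)
assert promote_to_3d((5,)) == (1, 5, 1)
assert promote_to_3d((4, 5)) == (4, 5, 1)
assert promote_to_3d((2, 3, 4)) == (2, 3, 4)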

template<typename xpu>
void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
const std::vector<TBlob>& inputs,
@@ -154,6 +185,37 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
});
}

template<typename xpu>
void DStackGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  ConcatParam param = nnvm::get<ConcatParam>(attrs.parsed);
  param.dim = 2;
  std::vector<TBlob> modified_outputs(outputs.size());
  // Promote the gradient outputs to 3-D with the same rule used in DStackCompute,
  // so the Concat backward pass slices along axis 2.
  for (int i = 0; i < param.num_args; ++i) {
    if (outputs[i].shape_.ndim() == 0) {
      modified_outputs[i] = outputs[i].reshape(TShape(3, 1));
    } else if (outputs[i].shape_.ndim() == 1) {
      TShape t = TShape(3, 1);
      t[1] = outputs[i].shape_[0];
      modified_outputs[i] = outputs[i].reshape(t);
    } else if (outputs[i].shape_.ndim() == 2) {
      TShape t = TShape(3, 1);
      t[0] = outputs[i].shape_[0];
      t[1] = outputs[i].shape_[1];
      modified_outputs[i] = outputs[i].reshape(t);
    } else {
      modified_outputs[i] = outputs[i];
    }
  }
  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
    ConcatOp<xpu, DType> op;
    op.Init(param);
    op.Backward(ctx, inputs[concat_enum::kOut], req, modified_outputs);
  });
}

/*!
* \brief concat CSRNDArray on the first dimension.
*/
113 changes: 113 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
@@ -248,5 +248,118 @@ NNVM_REGISTER_OP(_np_squeeze)
.add_argument("a", "NDArray-or-Symbol[]", "data to squeeze")
.add_arguments(SqueezeParam::__FIELDS__());

bool DStackShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  ConcatParam param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;
  dim_t size = 0;
  bool has_unknown_dim_size = false;
  int axis = 2;
  param_.dim = axis;
  for (int i = 0; i < param_.num_args; ++i) {
    // Promote 0-D/1-D/2-D input shapes to 3-D, matching the reshape in DStackCompute.
    if ((*in_shape)[i].ndim() == 0) {
      (*in_shape)[i] = mxnet::TShape(3, 1);
    } else if ((*in_shape)[i].ndim() == 1) {
      mxnet::TShape t = mxnet::TShape(3, 1);
      t[1] = (*in_shape)[i][0];
      (*in_shape)[i] = t;
    } else if ((*in_shape)[i].ndim() == 2) {
      mxnet::TShape t = mxnet::TShape(3, 1);
      t[0] = (*in_shape)[i][0];
      t[1] = (*in_shape)[i][1];
      (*in_shape)[i] = t;
    }
    mxnet::TShape &tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      CheckAxis(axis, tmp.ndim());
      if (!mxnet::dim_size_is_known(tmp, axis)) {
        has_unknown_dim_size = true;
      } else {
        size += tmp[axis];
      }
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == -1) return false;
  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  // Sizes along the concat axis are summed; all other dimensions must agree.
  if (!has_unknown_dim_size) {
    dshape[axis] = size;
  }
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];

  return shape_is_known(dshape);
}
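To make the inference above concrete: sizes along axis 2 are summed, every other dimension must agree, and lower-rank inputs are promoted first. A small check against NumPy, whose dstack this operator follows (shapes are illustrative):

import numpy as np

# (2, 3, 4) and (2, 3, 1): axis-2 sizes add, the other axes must match.
assert np.dstack((np.ones((2, 3, 4)), np.ones((2, 3, 1)))).shape == (2, 3, 5)

# A 2-D (2, 3) input is promoted to (2, 3, 1) before the sizes are summed.
assert np.dstack((np.ones((2, 3, 4)), np.ones((2, 3)))).shape == (2, 3, 5)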

bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type);

struct NumpyConcatGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    CHECK_EQ(ograds.size(), 1);
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};

NNVM_REGISTER_OP(_npi_dstack)
.describe(R"code(Stack tensors in sequence depthwise (in third dimension))code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
    std::vector<std::string> ret;
    for (int i = 0; i < params.num_args; ++i) {
      ret.push_back(std::string("data") + std::to_string(i));
    }
    return ret;
  })
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"out"};
  })
.set_attr<std::string>("key_var_num_args", "num_args")
.set_attr<nnvm::FInferType>("FInferType", ConcatType)
.set_attr<mxnet::FInferShape>("FInferShape", DStackShape)
.set_attr<FCompute>("FCompute<cpu>", DStackCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", NumpyConcatGrad{"_backward_np_dstack"})
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_np_dstack)
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", DStackGradCompute<cpu>);

} // namespace op
} // namespace mxnet
7 changes: 7 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
@@ -37,5 +37,12 @@ NNVM_REGISTER_OP(_np_reshape)
NNVM_REGISTER_OP(_np_squeeze)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);


NNVM_REGISTER_OP(_npi_dstack)
.set_attr<FCompute>("FCompute<gpu>", DStackCompute<gpu>);

NNVM_REGISTER_OP(_backward_np_dstack)
.set_attr<FCompute>("FCompute<gpu>", DStackGradCompute<gpu>);

} // namespace op
} // namespace mxnet
61 changes: 61 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -776,6 +776,67 @@ def hybrid_forward(self, F, x):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@use_np
def test_np_dstack():
    class TestDStack(HybridBlock):
        def __init__(self):
            super(TestDStack, self).__init__()

        def hybrid_forward(self, F, a, *args):
            return F.np.dstack([a] + list(args))

    def get_new_shape(shape):
        if len(shape) < 3:
            return shape
        axis = 2
        shape_lst = list(shape)
        shape_lst[axis] = random.randint(0, 5)
        return tuple(shape_lst)

    shapes = [
        (),
        (1,),
        (2, 1),
        (2, 2, 4),
        (2, 0, 0),
        (0, 1, 3),
        (2, 0, 3),
        (2, 3, 4, 5)
    ]
    for hybridize in [True, False]:
        for shape in shapes:
            test_dstack = TestDStack()
            if hybridize:
                test_dstack.hybridize()
            # test symbolic forward
            a = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
            a.attach_grad()
            b = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
            b.attach_grad()
            c = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
            c.attach_grad()
            d = mx.nd.random.uniform(shape=get_new_shape(shape)).as_np_ndarray()
            d.attach_grad()
            with mx.autograd.record():
                mx_out = test_dstack(a, b, c, d)
            np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
            assert mx_out.shape == np_out.shape
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

            # test symbolic backward
            mx_out.backward()
            assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
            assert_almost_equal(b.grad.asnumpy(), _np.ones(b.shape), rtol=1e-3, atol=1e-5)
            assert_almost_equal(c.grad.asnumpy(), _np.ones(c.shape), rtol=1e-3, atol=1e-5)
            assert_almost_equal(d.grad.asnumpy(), _np.ones(d.shape), rtol=1e-3, atol=1e-5)

            # test imperative
            mx_out = np.dstack((a, b, c, d))
            np_out = _np.dstack((a.asnumpy(), b.asnumpy(), c.asnumpy(), d.asnumpy()))
            assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()
