This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enable slice embedding concat split fuse #14491

Open

Wants to merge 37 commits into base: master

37 commits
ea35b12  enable slice embedding concat split fuse  (JustForFun099, Mar 21, 2019)
19529ca  fix code style  (JustForFun099, Mar 21, 2019)
70d75c9  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Mar 25, 2019)
bb00fb9  retrigger ci  (JustForFun099, Mar 25, 2019)
db29910  retrigger ci  (JustForFun099, Mar 25, 2019)
3fb6138  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Mar 26, 2019)
be7ded1  fix code style  (JustForFun099, Mar 26, 2019)
2851998  fix clang error  (JustForFun099, Mar 26, 2019)
a62565d  remove local copy  (JustForFun099, Mar 27, 2019)
415b989  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Mar 27, 2019)
ef952ab  retrigger ci  (JustForFun099, Mar 27, 2019)
c8c3ad3  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Mar 28, 2019)
a7cf220  add test case  (Mar 28, 2019)
8be9906  retrigger ci  (JustForFun099, Mar 28, 2019)
7adf09a  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Mar 29, 2019)
b53fa84  change the testcase to mkldnn file path  (JustForFun099, Mar 29, 2019)
c938039  retrigger ci  (JustForFun099, Mar 29, 2019)
d91a106  Merge remote-tracking branch 'upstream/master' into fuseop  (Apr 10, 2019)
67f8bb1  retrigger ci  (JustForFun099, Apr 10, 2019)
e81be48  retrigger ci  (JustForFun099, Apr 11, 2019)
9bec922  change test case path  (JustForFun099, Apr 11, 2019)
2a52a1c  skip test case for gpu  (JustForFun099, Apr 11, 2019)
55f68bd  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Apr 12, 2019)
557e3e7  retrigger ci  (JustForFun099, Apr 12, 2019)
c3e9431  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Apr 14, 2019)
f897e4e  retrigger ci  (JustForFun099, Apr 14, 2019)
39b9433  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Apr 15, 2019)
3989fe2  fix conflict  (JustForFun099, Apr 18, 2019)
deea19b  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Apr 18, 2019)
7ffed68  fix conflict issue  (JustForFun099, Apr 18, 2019)
59f82b7  Merge remote-tracking branch 'upstream/master' into fuseop  (JustForFun099, Apr 19, 2019)
5217e36  fix testcase fail issue  (JustForFun099, Apr 19, 2019)
d4b8c89  fix code style  (JustForFun099, Apr 19, 2019)
c4dc8b7  retrigger ci  (JustForFun099, Apr 19, 2019)
fdf35ae  retrigger ci  (JustForFun099, Apr 19, 2019)
b93a14d  retrigger ci  (JustForFun099, Apr 19, 2019)
30cd27d  retrigger ci  (JustForFun099, Apr 19, 2019)
24 changes: 14 additions & 10 deletions src/operator/nn/concat.cc
@@ -32,20 +32,19 @@
 namespace mxnet {
 namespace op {
 
-static bool ConcatShape(const nnvm::NodeAttrs& attrs,
-                        mxnet::ShapeVector *in_shape,
-                        mxnet::ShapeVector *out_shape) {
+bool ConcatSetShape(mxnet::ShapeVector *in_shape,
+                    mxnet::ShapeVector *out_shape, int num_args, int dim) {
   using namespace mshadow;
-  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
-  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(num_args));
   mxnet::TShape dshape;
   dim_t size = 0;
   bool has_unknown_dim_size = false;
   int axis = -1;
-  for (int i = 0; i < param_.num_args; ++i) {
+  for (int i = 0; i < num_args; ++i) {
     mxnet::TShape tmp = (*in_shape)[i];
     if (tmp.ndim() > 0) {
-      axis = CheckAxis(param_.dim, tmp.ndim());
+      axis = CheckAxis(dim, tmp.ndim());
       has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
       size += tmp[axis];
       tmp[axis] = -1;
@@ -55,15 +54,15 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
 
   mxnet::TShape tmp = (*out_shape)[0];
   if (tmp.ndim() > 0) {
-    axis = CheckAxis(param_.dim, tmp.ndim());
+    axis = CheckAxis(dim, tmp.ndim());
     tmp[axis] = -1;
     shape_assign(&dshape, tmp);
   }
 
   if (dshape.ndim() == -1) return false;
   CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";
 
-  for (int i = 0; i < param_.num_args; ++i) {
+  for (int i = 0; i < num_args; ++i) {
     CHECK(shape_assign(&(*in_shape)[i], dshape))
         << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
   }
@@ -74,7 +73,12 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
 
   return shape_is_known(dshape);
 }
 
+static bool ConcatShape(const nnvm::NodeAttrs& attrs,
+                        mxnet::ShapeVector *in_shape,
+                        mxnet::ShapeVector *out_shape) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  return ConcatSetShape(in_shape, out_shape, param_.num_args, param_.dim);
+}
 // Concat for RNN param deals with the reverse shape inference from output
 // for the special case of concatenating RNN parameters.
 // The first (and sometimes the second) input may be unknown on the target axis.
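Note: the hunk above separates parameter parsing from the shape arithmetic, so ConcatSetShape can be reused by shape functions that do not carry a ConcatParam, such as the fused slice/embedding/concat/split op this PR enables. Below is a minimal sketch of such a caller; the name FusedShape and the hard-coded axis are hypothetical illustrations, not code from this PR:

// Hypothetical caller (sketch only): reuses ConcatSetShape directly instead of
// duplicating the concat shape rule. A real fused op would take num_args and
// dim from its own parameter struct rather than hard-coding them.
static bool FusedShape(const nnvm::NodeAttrs& attrs,
                       mxnet::ShapeVector* in_shape,
                       mxnet::ShapeVector* out_shape) {
  const int num_args = static_cast<int>(in_shape->size());  // one entry per input
  const int dim = 1;                                        // example concat axis
  return ConcatSetShape(in_shape, out_shape, num_args, dim);
}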
135 changes: 71 additions & 64 deletions src/operator/slice_channel-inl.h
@@ -146,7 +146,75 @@ class SliceChannelOp : public Operator {

 template<typename xpu>
 Operator *CreateOp(SliceChannelParam param, int dtype);
+
+inline bool SliceChannelInferShape(mxnet::ShapeVector *in_shape,
+                                   mxnet::ShapeVector *out_shape,
+                                   mxnet::ShapeVector *aux_shape,
+                                   int num_outputs, int axis, bool squeeze_axis) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U);
+  mxnet::TShape dshape = in_shape->at(slice_enum::kData);
+  mxnet::TShape ishape = in_shape->at(slice_enum::kData);
Contributor review comment on the two lines above: "Does in_shape->at have to be accessed twice? Can we do it once and reuse?"
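A minimal sketch of the single-access variant the reviewer is asking about (an assumed fix, not code from this PR): fetch the shape once, keep a const reference for the read-only uses, and take one mutable copy.

// Sketch only: one call to in_shape->at, then reuse.
const mxnet::TShape& ishape = in_shape->at(slice_enum::kData);  // read-only uses
mxnet::TShape dshape = ishape;  // mutable working copy, modified below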

+  if (!mxnet::ndim_is_known(dshape)) return false;
+  if (axis >= 0) {
+    CHECK_LT(axis, dshape.ndim());
+  } else {
+    CHECK_LT(axis + dshape.ndim(), dshape.ndim());
+  }
+  int real_axis = axis;
+  if (real_axis < 0) {
+    real_axis += dshape.ndim();
+  }
+  CHECK_EQ(dshape[real_axis] % num_outputs, 0U)
+      << "You are trying to split the " << real_axis
+      << "-th axis of input tensor with shape " << dshape
+      << " into num_outputs=" << num_outputs
+      << " evenly sized chunks, but this is not possible because "
+      << num_outputs << " does not evenly divide "
+      << dshape[real_axis];
+  if (squeeze_axis && ishape[real_axis] != -1) {
+    CHECK_EQ(ishape[real_axis], num_outputs)
+        << "If squeeze axis is True, the size of the sliced axis"
+        << " must be the same as num_outputs."
+        << " Input shape=" << ishape << ", axis=" << real_axis
+        << ", num_outputs=" << num_outputs << ".";
+  }
+  if (dshape[real_axis] >= 0) {
+    dshape[real_axis] /= num_outputs;
+  }
+  if (squeeze_axis && (dshape[real_axis] == 1
+      || !mxnet::dim_size_is_known(ishape, real_axis))) {
+    for (int d = real_axis; d < dshape.ndim() - 1; ++d) {
+      dshape[d] = dshape[d + 1];
+    }
+    dshape = mxnet::TShape(&dshape[0], &dshape[dshape.ndim() - 1]);
+  }
+  CHECK_EQ(static_cast<int>((*out_shape).size()), num_outputs)
+      << "Size of output shape mismatch!";
+  for (int i = 0; i < num_outputs; ++i) {
+    SHAPE_ASSIGN_CHECK(*out_shape, i, dshape);
+    // Perform incomplete shape inference.
+    // We can back-calculate the inshape based on the out_shape.
+    mxnet::TShape back_calculate_dshape = ishape;
+    if (squeeze_axis && (dshape.ndim() == ishape.ndim() - 1)) {
+      for (int d = 0; d < real_axis; ++d) {
+        back_calculate_dshape[d] = (*out_shape)[i][d];
+      }
+      back_calculate_dshape[real_axis] = num_outputs;
+      for (int d = real_axis + 1; d < static_cast<int>(ishape.ndim()); ++d) {
+        back_calculate_dshape[d] = (*out_shape)[i][d - 1];
+      }
+    } else {
+      for (int d = 0; d < static_cast<int>(ishape.ndim()); ++d) {
+        back_calculate_dshape[d] = (*out_shape)[i][d];
+        if (d == real_axis) {
+          back_calculate_dshape[d] *= num_outputs;
+        }
+      }
+    }
+    SHAPE_ASSIGN_CHECK(*in_shape, slice_enum::kData, back_calculate_dshape);
+  }
+  return true;
+}

 #if DMLC_USE_CXX11
 class SliceChannelProp : public OperatorProperty {
@@ -191,69 +259,8 @@
   bool InferShape(mxnet::ShapeVector *in_shape,
                   mxnet::ShapeVector *out_shape,
                   mxnet::ShapeVector *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 1U);
-    mxnet::TShape dshape = in_shape->at(slice_enum::kData);
-    mxnet::TShape ishape = in_shape->at(slice_enum::kData);
-    if (!mxnet::ndim_is_known(dshape)) return false;
-    if (param_.axis >= 0) {
-      CHECK_LT(param_.axis, dshape.ndim());
-    } else {
-      CHECK_LT(param_.axis + dshape.ndim(), dshape.ndim());
-    }
-    int real_axis = param_.axis;
-    if (real_axis < 0) {
-      real_axis += dshape.ndim();
-    }
-    CHECK_EQ(dshape[real_axis] % param_.num_outputs, 0U)
-        << "You are trying to split the " << real_axis
-        << "-th axis of input tensor with shape " << dshape
-        << " into num_outputs=" << param_.num_outputs
-        << " evenly sized chunks, but this is not possible because "
-        << param_.num_outputs << " does not evenly divide "
-        << dshape[real_axis];
-    if (param_.squeeze_axis && ishape[real_axis] != -1) {
-      CHECK_EQ(ishape[real_axis], param_.num_outputs)
-          << "If squeeze axis is True, the size of the sliced axis must be the same as num_outputs."
-          << " Input shape=" << ishape << ", axis=" << real_axis
-          << ", num_outputs=" << param_.num_outputs << ".";
-    }
-    if (dshape[real_axis] >= 0) {
-      dshape[real_axis] /= param_.num_outputs;
-    }
-    if (param_.squeeze_axis && (dshape[real_axis] == 1
-        || !mxnet::dim_size_is_known(ishape, real_axis))) {
-      for (int d = real_axis; d < dshape.ndim() - 1; ++d) {
-        dshape[d] = dshape[d+1];
-      }
-      dshape = mxnet::TShape(&dshape[0], &dshape[dshape.ndim()-1]);
-    }
-    CHECK_EQ(static_cast<int>((*out_shape).size()), param_.num_outputs)
-        << "Size of output shape mismatch!";
-    for (int i = 0; i < param_.num_outputs; ++i) {
-      SHAPE_ASSIGN_CHECK(*out_shape, i, dshape);
-      // Perform incomplete shape inference.
-      // We can back-calculate the inshape based on the out_shape.
-      mxnet::TShape back_calculate_dshape = ishape;
-      if (param_.squeeze_axis && (dshape.ndim() == ishape.ndim() - 1)) {
-        for (int d = 0; d < real_axis; ++d) {
-          back_calculate_dshape[d] = (*out_shape)[i][d];
-        }
-        back_calculate_dshape[real_axis] = param_.num_outputs;
-        for (int d = real_axis + 1; d < static_cast<int>(ishape.ndim()); ++d) {
-          back_calculate_dshape[d] = (*out_shape)[i][d - 1];
-        }
-      } else {
-        for (int d = 0; d < static_cast<int>(ishape.ndim()); ++d) {
-          back_calculate_dshape[d] = (*out_shape)[i][d];
-          if (d == real_axis) {
-            back_calculate_dshape[d] *= param_.num_outputs;
-          }
-        }
-      }
-      SHAPE_ASSIGN_CHECK(*in_shape, slice_enum::kData, back_calculate_dshape);
-    }
-    return true;
+    return SliceChannelInferShape(
+        in_shape, out_shape, aux_shape, param_.num_outputs, param_.axis, param_.squeeze_axis);
   }

   OperatorProperty* Copy() const override {
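The forward and backward shape arithmetic in SliceChannelInferShape can be checked in isolation. Below is a self-contained sketch using plain std::vector in place of mxnet::TShape; the concrete shape (4, 6), axis, and num_outputs are illustrative only, not taken from the PR:

#include <cassert>
#include <vector>

int main() {
  const std::vector<int> ishape = {4, 6};  // illustrative input shape
  const int axis = 1, num_outputs = 3;

  // Forward: the sliced axis must divide evenly; each output keeps
  // 1/num_outputs of it, so (4, 6) split 3 ways on axis 1 gives (4, 2).
  assert(ishape[axis] % num_outputs == 0);
  std::vector<int> oshape = ishape;
  oshape[axis] /= num_outputs;

  // Backward ("incomplete shape inference" in the PR): from a known output
  // shape, multiplying the sliced axis by num_outputs recovers the input.
  std::vector<int> back = oshape;
  back[axis] *= num_outputs;
  assert(back == ishape);
  return 0;
}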