7 changes: 3 additions & 4 deletions src/operator/contrib/batch_norm_relu.cc
@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 4U);
   const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

   const size_t channelAxis = static_cast<size_t>(param.axis < 0
                  ? static_cast<int>(dshape.ndim()) + param.axis
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,

   const int channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
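The pattern in this file, repeated in every file below, is to test `mxnet::ndim_is_known` before any read of `dshape.ndim()` or `dshape[...]`. A minimal, self-contained sketch of why the ordering matters, assuming MXNet's convention that `ndim() == -1` encodes an unknown shape (the `Shape` and `InferChannelAxis` names here are illustrative, not MXNet's API):

```cpp
#include <cassert>

struct Shape {
  int ndim_;                                  // -1 encodes "not yet known"
  int ndim() const { return ndim_; }
};

bool ndim_is_known(const Shape& s) { return s.ndim() != -1; }

// Returning false tells the shape-inference pass to try again later,
// once more shape information has propagated through the graph.
bool InferChannelAxis(const Shape& dshape, int axis, int* out) {
  if (!ndim_is_known(dshape)) return false;   // guard must come first
  *out = axis < 0 ? dshape.ndim() + axis : axis;
  return true;
}

int main() {
  int ax = 0;
  // Without the guard, axis = -1 + (-1) = -2 would be cast to size_t
  // and used to index dshape, exactly the misuse this patch prevents.
  assert(!InferChannelAxis(Shape{-1}, -1, &ax));
  assert(InferChannelAxis(Shape{4}, -1, &ax) && ax == 3);
  return 0;
}
```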
7 changes: 3 additions & 4 deletions src/operator/nn/batch_norm.cc
@@ -323,6 +323,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 3U);
   const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }

   const size_t channelAxis = static_cast<size_t>(param.axis < 0
                  ? static_cast<int>(dshape.ndim()) + param.axis
@@ -331,10 +334,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,

   const index_t channelCount = dshape[channelAxis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
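Returning false early is only safe because the surrounding pass keeps iterating. A toy fixpoint driver, assuming nothing about nnvm's real scheduler (all names here are hypothetical), shows the contract: a shape function returning false means "inputs unknown, ask me again".

```cpp
#include <functional>
#include <vector>

// Toy driver: re-run shape functions until none makes progress.
bool RunToFixpoint(const std::vector<std::function<bool()>>& shape_fns) {
  std::vector<bool> done(shape_fns.size(), false);
  bool progress = true;
  while (progress) {
    progress = false;
    for (size_t i = 0; i < shape_fns.size(); ++i) {
      if (!done[i] && shape_fns[i]()) {
        done[i] = true;
        progress = true;
      }
    }
  }
  for (bool d : done)
    if (!d) return false;   // some shape could never be resolved
  return true;
}

int main() {
  int ndim = -1;                               // starts unknown
  std::vector<std::function<bool()>> fns = {
      [&] { ndim = 4; return true; },          // producer resolves the shape
      [&] { return ndim != -1; },              // consumer needs it known
  };
  return RunToFixpoint(fns) ? 0 : 1;
}
```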
8 changes: 4 additions & 4 deletions src/operator/nn/group_norm.cc
@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
-  CHECK_GE(dshape.ndim(), 3U);
-  const int num_groups = param.num_groups;
-  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
-
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
   }

+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
   in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
   in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));
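Here the old ordering was worse than a bad axis computation: `dshape[1]` was indexed and `dshape.ndim()` compared before the unknown-shape guard ran. A toy version of the reordered logic, with illustrative types rather than MXNet's:

```cpp
#include <cassert>
#include <vector>

struct Shape {
  std::vector<int> dims;                       // empty encodes "unknown"
  int ndim() const { return dims.empty() ? -1 : static_cast<int>(dims.size()); }
  int operator[](int i) const { return dims[i]; }
};

bool ndim_is_known(const Shape& s) { return s.ndim() != -1; }

bool GroupNormShapeToy(const Shape& dshape, int num_groups, Shape* gamma) {
  if (!ndim_is_known(dshape)) return false;    // guard before any read
  assert(dshape.ndim() >= 3);
  assert(dshape[1] % num_groups == 0);         // channels divisible by groups
  gamma->dims = {dshape[1]};                   // one scale per channel
  return true;
}

int main() {
  Shape gamma;
  assert(!GroupNormShapeToy(Shape{}, 4, &gamma));           // unknown: retry
  assert(GroupNormShapeToy(Shape{{2, 8, 32, 32}}, 4, &gamma));
  assert(gamma.dims[0] == 8);
  return 0;
}
```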
7 changes: 4 additions & 3 deletions src/operator/nn/layer_norm.cc
@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   int axis = GetRealAxis(param.axis, dshape.ndim());
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;

   const index_t channelCount = dshape[axis];

-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
   SHAPE_ASSIGN_CHECK(*in_shape,
                      layernorm::kGamma,
                      mxnet::TShape(Shape1(channelCount)));
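LayerNorm's guard has to precede `GetRealAxis`, which folds a negative axis into `[0, ndim)`. A plausible reading of that helper (a sketch under that assumption, not the definition from layer_norm.cc):

```cpp
#include <cassert>

// Sketch of a GetRealAxis-style helper: negative axes count from the
// back, so axis -1 on a 3-d tensor means axis 2. With ndim unknown
// (-1), the same arithmetic silently yields a nonsense axis of -2,
// which is why the patch checks ndim_is_known first.
int GetRealAxisToy(int axis, int ndim) {
  return axis < 0 ? axis + ndim : axis;
}

int main() {
  assert(GetRealAxisToy(-1, 3) == 2);    // last axis of a 3-d shape
  assert(GetRealAxisToy(1, 3) == 1);     // non-negative axes pass through
  assert(GetRealAxisToy(-1, -1) == -2);  // unknown ndim: garbage result
  return 0;
}
```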
9 changes: 7 additions & 2 deletions src/operator/nn/pooling.cc
@@ -95,10 +95,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
                          mxnet::ShapeVector *out_shape) {
   const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), 1U);
+  const mxnet::TShape &dshape = (*in_shape)[0];
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   if (param.pool_type == pool_enum::kLpPooling) {
     CHECK(param.p_value.has_value());
   }
-  const mxnet::TShape &dshape = (*in_shape)[0];
+
   if (param.pooling_convention == pool_enum::kSame) {
     CHECK_EQ(dshape.ndim(), 3U)
       << "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +119,7 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       << "Pooling: Input data should be 3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
       << " Or 5D in (batch, channel, d, y, x)";
-  if (!mxnet::ndim_is_known(dshape)) return false;
+
   int layout = param.GetLayout(dshape.ndim());
   if (param.global_pool) {
     mxnet::TShape oshape = dshape;
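Once `dshape` is known, the rest of PoolingShape computes an output extent per spatial axis. The formulas below are the usual pooling conventions (a hedged sketch; the exact expressions live further down in pooling.cc than this diff shows):

```cpp
#include <cassert>

// Conventional pooling output sizes: "valid" floors the division,
// kSame-style conventions effectively ceil it.
int OutValid(int x, int kernel, int pad, int stride) {
  return (x + 2 * pad - kernel) / stride + 1;   // integer floor division
}
int OutSame(int x, int stride) {
  return (x + stride - 1) / stride;             // ceil(x / stride)
}

int main() {
  assert(OutValid(32, 3, 1, 2) == 16);  // (32 + 2 - 3) / 2 + 1
  assert(OutSame(32, 2) == 16);
  assert(OutSame(33, 2) == 17);         // odd input still covered
  return 0;
}
```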
22 changes: 17 additions & 5 deletions src/operator/tensor/matrix_op-inl.h
@@ -456,9 +456,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+  if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
     return false;  // none of the shapes is known
+  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
   if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
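TransposeShape is bidirectional: either side being known is enough, which is why the guard only bails when both `shp` and `out_shp` are unknown (note also the swap from a raw `== -1` comparison to the `ndim_is_known` helper). In miniature, with illustrative names:

```cpp
#include <cassert>
#include <vector>

// out[i] = in[axes[i]]: a known input fixes the output, and, because a
// permutation is invertible, a known output equally fixes the input.
std::vector<int> TransposeForward(const std::vector<int>& in,
                                  const std::vector<int>& axes) {
  std::vector<int> out(axes.size());
  for (size_t i = 0; i < axes.size(); ++i)
    out[i] = in[axes[i]];
  return out;
}

int main() {
  const std::vector<int> in{2, 3, 4};
  assert((TransposeForward(in, {2, 0, 1}) == std::vector<int>{4, 2, 3}));
  return 0;
}
```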
@@ -513,12 +513,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  mxnet::TShape& oshape = (*out_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
     return false;
   }

-  mxnet::TShape& ishape = (*in_attrs)[0];
-  mxnet::TShape& oshape = (*out_attrs)[0];
   int indim = ishape.ndim();
   bool unknown_ishape = false;
   if (-1 == indim) {
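expand_dims only inserts a length-1 axis, so like transpose it can infer in either direction. A minimal forward version (illustrative, not the matrix_op-inl.h implementation, which also handles the reverse mapping):

```cpp
#include <cassert>
#include <vector>

// Insert a new axis of extent 1 at position `axis` (0 <= axis <= ndim).
std::vector<int> ExpandDims(const std::vector<int>& shape, int axis) {
  std::vector<int> out(shape.begin(), shape.begin() + axis);
  out.push_back(1);                              // the new unit axis
  out.insert(out.end(), shape.begin() + axis, shape.end());
  return out;
}

int main() {
  assert((ExpandDims({2, 3}, 0) == std::vector<int>{1, 2, 3}));
  assert((ExpandDims({2, 3}, 2) == std::vector<int>{2, 3, 1}));
  return 0;
}
```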
@@ -1441,6 +1441,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& ishape = (*in_attrs)[0];
   mxnet::TShape& from_shape = (*in_attrs)[1];
+  if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+    return false;
+  }
   if (param.axes.ndim() == 0) {
     CHECK_EQ(ishape.ndim(), from_shape.ndim())
       << "By default slice_axis performs slice on all axes, but ndim mismatch ";
@@ -1749,6 +1752,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   const mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
   int repeats = 0;
   dmlc::optional<int> axisOpt;
   GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2427,6 +2433,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(4, -1);

   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2591,6 +2600,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[0], 0)
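Both of these operators shuffle extents between the channel and spatial axes, so once `in_shape` is known the shape math is pure arithmetic. A sketch of the two mappings (standard depth-to-space and space-to-depth semantics, not code copied from matrix_op-inl.h):

```cpp
#include <cassert>
#include <vector>

using Shape4 = std::vector<int>;  // {N, C, H, W}

// depth_to_space: channel extent shrinks by block^2, spatial extents grow.
Shape4 DepthToSpaceShape(const Shape4& s, int block) {
  return {s[0], s[1] / (block * block), s[2] * block, s[3] * block};
}

// space_to_depth is the exact inverse.
Shape4 SpaceToDepthShape(const Shape4& s, int block) {
  return {s[0], s[1] * block * block, s[2] / block, s[3] / block};
}

int main() {
  Shape4 in{1, 8, 4, 4};
  Shape4 out = DepthToSpaceShape(in, 2);
  assert((out == Shape4{1, 2, 8, 8}));
  assert((SpaceToDepthShape(out, 2) == in));   // round-trips exactly
  return 0;
}
```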