Fix Strided Slice Infer Layout (apache#6621)
kevinthesun authored and Tushar Dey committed Oct 15, 2020
1 parent bbdb1f0 commit cf2795a
Showing 2 changed files with 124 additions and 33 deletions.
111 changes: 78 additions & 33 deletions src/relay/op/tensor/transform.cc
@@ -2175,52 +2175,97 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
       }
     }

-    Array<Integer> new_begin, new_end;
-
-    for (size_t i = 0; i < begin.size(); i++) {
-      const LayoutAxis& axis = layout[i];
-      if (!axis.IsPrimal()) {
-        // original layout that contains splitted axes is not supported
-        return {{Layout::Undef()}, {Layout::Undef()}};
-      }
-      auto factor = new_layout.FactorOf(axis);
-      if (factor == -1) {
-        new_begin.push_back(begin[i]);
-        new_end.push_back(end[i]);
-      } else {
-        if (strides.defined() && i < strides.size()) {
-          auto stride = strides[i];
-          // arbitrary stride is not supported
-          if (stride.defined() && stride->value != 1) {
-            return {{Layout::Undef()}, {Layout::Undef()}};
-          }
-        }
-        int64_t bg = begin[i].defined() ? begin[i]->value : 0;
-        int64_t ed;
-        if (!end[i].defined()) {
-          ed = shape[i].as<IntImmNode>()->value;
-        } else if (params->slice_mode == "size") {
-          if (end[i]->value < 0) {
-            ed = shape[i].as<IntImmNode>()->value;
-          } else {
-            ed = bg + end[i]->value;
-          }
-        } else {
-          ed = end[i]->value;
-        }
-
-        if (bg % factor || ed % factor) {
-          // transform to original layout
-          return {{Layout::Undef()}, {Layout::Undef()}};
-        }
-        new_begin.push_back(tvm::Integer(bg / factor));
-        new_end.push_back(tvm::Integer(ed / factor));
-      }
-    }
-
-    layout = new_layout;
-    params->begin = new_begin;
-    params->end = new_end;
+    Array<Integer> new_begin, new_end, new_strides;
+
+    // Handles layout conversion like NHWC -> NCHW
+    auto old_layout_name = layout.name();
+    auto new_layout_name = new_layout.name();
+
+    if (old_layout_name.rfind(new_layout_name, 0) != 0 &&
+        new_layout_name.rfind(old_layout_name, 0) != 0) {
+      if (old_layout_name.size() != new_layout_name.size()) {
+        // Not support NHW4c -> NCHW
+        return {{Layout::Undef()}, {Layout::Undef()}};
+      }
+      for (size_t i = 0; i < new_layout_name.size(); ++i) {
+        auto index = layout.IndexOf(new_layout[i]);
+        if (index == -1) {
+          return {{Layout::Undef()}, {Layout::Undef()}};
+        }
+
+        size_t new_index = static_cast<size_t>(index);
+        int64_t bg, ed, st;
+        if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) {
+          st = strides[new_index]->value;
+        } else {
+          st = 1;
+        }
+        if (new_index < begin.size() && begin[new_index].defined()) {
+          bg = begin[new_index]->value;
+        } else {
+          bg = 0;
+        }
+        if (new_index < end.size() && end[new_index].defined()) {
+          ed = end[new_index]->value;
+        } else {
+          ed = shape[new_index].as<IntImmNode>()->value;
+        }
+
+        new_begin.push_back(bg);
+        new_end.push_back(ed);
+        new_strides.push_back(st);
+      }
+      params->begin = new_begin;
+      params->end = new_end;
+      params->strides = new_strides;
+      layout = new_layout;
+    } else {
+      for (size_t i = 0; i < begin.size(); i++) {
+        const LayoutAxis& axis = layout[i];
+        if (!axis.IsPrimal()) {
+          // original layout that contains splitted axes is not supported
+          return {{Layout::Undef()}, {Layout::Undef()}};
+        }
+        auto factor = new_layout.FactorOf(axis);
+        if (factor == -1) {
+          new_begin.push_back(begin[i]);
+          new_end.push_back(end[i]);
+        } else {
+          if (strides.defined() && i < strides.size()) {
+            auto stride = strides[i];
+            // arbitrary stride is not supported
+            if (stride.defined() && stride->value != 1) {
+              return {{Layout::Undef()}, {Layout::Undef()}};
+            }
+          }
+          int64_t bg = begin[i].defined() ? begin[i]->value : 0;
+          int64_t ed;
+          if (!end[i].defined()) {
+            ed = shape[i].as<IntImmNode>()->value;
+          } else if (params->slice_mode == "size") {
+            if (end[i]->value < 0) {
+              ed = shape[i].as<IntImmNode>()->value;
+            } else {
+              ed = bg + end[i]->value;
+            }
+          } else {
+            ed = end[i]->value;
+          }
+
+          if (bg % factor || ed % factor) {
+            // transform to original layout
+            return {{Layout::Undef()}, {Layout::Undef()}};
+          }
+          new_begin.push_back(tvm::Integer(bg / factor));
+          new_end.push_back(tvm::Integer(ed / factor));
+        }
+      }
+
+      layout = new_layout;
+      params->begin = new_begin;
+      params->end = new_end;
+    }
   }
   return {{layout}, {layout}};
 }
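For intuition, the layout-conversion branch added above walks the axes of the new layout, looks up each axis's position in the old layout, and copies the begin/end/stride value from that position, defaulting to begin 0, end equal to the axis extent, and stride 1 when the original slice did not specify one. A rough standalone Python sketch of that mapping (illustration only, not TVM code; the shape and slice arguments are taken from the new test case below) could look like this:

def convert_slice_args(old_layout, new_layout, shape, begin, end, strides):
    """Permute begin/end/strides from old_layout order into new_layout order.

    Missing entries default to begin=0, end=shape[axis], stride=1, mirroring
    the defaulting done in the new branch of StridedSliceInferCorrectLayout.
    """
    new_begin, new_end, new_strides = [], [], []
    for axis in new_layout:
        i = old_layout.index(axis)  # position of this axis in the old layout
        new_begin.append(begin[i] if i < len(begin) else 0)
        new_end.append(end[i] if i < len(end) else shape[i])
        new_strides.append(strides[i] if i < len(strides) else 1)
    return new_begin, new_end, new_strides


# Values from the test case below: a (1, 64, 56, 56) NCHW tensor sliced with
# begin=[0, 1], end=[1, -1, 10], strides=[1, 1, 2, 1], converted to NHWC.
print(convert_slice_args("NCHW", "NHWC", [1, 64, 56, 56],
                         [0, 1], [1, -1, 10], [1, 1, 2, 1]))
# ([0, 0, 0, 1], [1, 10, 56, -1], [1, 2, 1, 1])

The printed result matches the begin/end/strides that the expected() graph in the test below uses after conversion to NHWC.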
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_convert_op_layout.py
@@ -921,6 +921,51 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)


+def test_conv_strided_slice_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = relay.nn.relu(y)
+        y = relay.strided_slice(y, begin=[0, 1], end=[1, -1, 10], strides=[1, 1, 2, 1])
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "HWIO")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.nn.relu(y)
+        y = relay.strided_slice(y, begin=[0, 0, 0, 1], end=[1, 10, 56, -1], strides=[1, 2, 1, 1])
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_default_keyword():
     """ Check that the default keyword selects correct TVM default layout. """

@@ -1136,6 +1181,7 @@ def expected():
     test_conv_convert_kernel_layout()
     test_conv_transpose_convert_layout()
     test_conv_roi_align_convert_layout()
+    test_conv_strided_slice_convert_layout()
     test_default_keyword()
     test_different_ops_convert_layout()
     test_no_desired_layout()
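As a usage sketch (not part of the commit), the same conversion can be exercised outside the test harness by applying the ConvertLayout pass to the before() graph defined in the test above, using the standard TVM Python API:

import tvm
from tvm import relay

# Assumes the before() function from test_conv_strided_slice_convert_layout above.
mod = tvm.IRModule.from_expr(before())
seq = tvm.transform.Sequential(
    [relay.transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}),
     relay.transform.InferType()]
)
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod)  # strided_slice now carries NHWC-ordered begin/end/strides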
