Commit 4ee534b

anijain2305 authored and yzhliu committed
[Relay][AlterOpLayout] NHWC to NCHWc pad operator. (#4103)
* [Relay][AlterOpLayout] NHWC to NCHWc pad operator.
* Fixing culprit.
* Flaky test 1.
* Flaky test 2.
1 parent bc54310 commit 4ee534b

3 files changed (+195, -1 lines)

src/relay/op/nn/pad.cc

Lines changed: 77 additions & 0 deletions
@@ -36,6 +36,82 @@ namespace relay {
 // relay.nn.pad
 TVM_REGISTER_NODE_TYPE(PadAttrs);
 
+Array<Array<Layout> > PadInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>> &old_in_shapes) {
+  // NOTE: Discard "const" qualifier here.
+  PadAttrs *params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
+
+  Layout ret;
+  // If new_in_layouts are defined, this code tries to modify the layout.
+  bool is_layout_modified = new_in_layouts.defined();
+  if (new_in_layouts.defined()) {
+    // Create a map from axis to pad_width. For the new layout, a new pad_width is generated
+    // using the map. The new layout is rejected if the padding happens along an axis that
+    // was split.
+
+    // 1) Create a map from axis to pad_width using the old layout.
+    std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
+    int index_counter = 0;
+    CHECK_EQ(new_in_layouts.size(), 1);
+    CHECK_EQ(old_in_layouts.size(), 1);
+    for (auto iter_var : old_in_layouts[0]->axes) {
+      const auto& old_layout_axis = LayoutAxis::Get(iter_var);
+      axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
+      index_counter++;
+    }
+
+    // 2) Create the new pad_width by walking over the new layout and using the map.
+    tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
+    for (auto iter_var : new_in_layouts[0]->axes) {
+      const auto& new_layout_axis = LayoutAxis::Get(iter_var);
+      auto axis_name = new_layout_axis.name();
+      if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
+        // This is a primal axis, so directly use the original pad_width.
+        new_pad_width.push_back(axis_pad_width.at(axis_name));
+      } else {
+        // This is an axis that got split, so check that pad_width was [0, 0] originally.
+        const auto& dual_axis = new_layout_axis.ToPrimal();
+        auto dual_axis_name = dual_axis.name();
+        CHECK(axis_pad_width.count(dual_axis_name))
+            << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
+        new_pad_width.push_back(axis_pad_width.at(dual_axis_name));
+
+        // If any pad_width element is not zero, do not change the layout.
+        for (auto width : axis_pad_width.at(dual_axis_name)) {
+          if (auto* width_imm = width.as<IntImm>()) {
+            if (width_imm->value != 0) {
+              is_layout_modified = false;
+            }
+          } else {
+            is_layout_modified = false;
+          }
+        }
+      }
+    }
+
+    // If the above conditions are satisfied, we can set the newly created pad_width and use
+    // the new layout.
+    if (is_layout_modified) {
+      ret = new_in_layouts[0];
+      params->pad_width = new_pad_width;
+    }
+  }
+
+  if (!is_layout_modified) {
+    if (old_in_layouts.defined()) {
+      CHECK_EQ(old_in_layouts.size(), 1);
+      ret = old_in_layouts[0];
+    } else {
+      ret = Layout::Undef();
+    }
+  }
+
+  return Array<Array<Layout> >{{ret}, {ret}};
+}
+
 bool PadRel(const Array<Type>& types,
             int num_inputs,
             const Attrs& attrs,
@@ -133,6 +209,7 @@ RELAY_REGISTER_OP("nn.pad")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
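
For intuition, here is a minimal plain-Python sketch of the remapping that PadInferCorrectLayout performs; the function name remap_pad_width and the list-of-axis-names representation are illustrative only, not TVM API. Each primal axis keeps its original pad pair, a split sub-axis inherits its primal axis's pad pair, and the new layout is rejected whenever padding falls on an axis that was split.

# Hypothetical, plain-Python illustration (not TVM API) of the pad_width remapping above.
def remap_pad_width(old_axes, new_axes, pad_width):
    """Return pad_width for new_axes, or None if the new layout must be rejected."""
    axis_pad = dict(zip(old_axes, pad_width))      # 1) old primal axis -> pad pair
    new_pad = []
    for axis in new_axes:                          # 2) walk the new layout
        if axis in axis_pad:                       # primal axis: reuse the original pad pair
            new_pad.append(axis_pad[axis])
        else:                                      # sub-axis of a split primal axis (e.g. 'c')
            primal_pad = axis_pad[axis.upper()]
            if primal_pad != (0, 0):               # padding along a split axis
                return None                        # -> keep the old layout instead
            new_pad.append(primal_pad)
    return new_pad

# NCHW -> NCHW16c: H/W padding is preserved, the inner channel sub-axis gets (0, 0).
print(remap_pad_width(list("NCHW"), list("NCHW") + ["c"],
                      ((0, 0), (0, 0), (1, 1), (1, 1))))
# [(0, 0), (0, 0), (1, 1), (1, 1), (0, 0)]

In the third test case below, padding is requested along the channel axis, so this mapping bails out and the pad stays in NCHW, with a layout_transform inserted before it.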

tests/python/frontend/coreml/test_forward.py

Lines changed: 3 additions & 1 deletion
@@ -47,8 +47,10 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
     model = cm.models.MLModel(model_file)
     x = model_zoo.get_cat_image()
     shape_dict = {input_name : x.shape}
-    mod, params = relay.frontend.from_coreml(model, shape_dict)
+    # Some Relay passes change operators on the fly. Ensure that we generate
+    # a new graph for each target.
     for target, ctx in ctx_list():
+        mod, params = relay.frontend.from_coreml(model, shape_dict)
         tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
         print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))

tests/python/relay/test_pass_alter_op_layout.py

Lines changed: 115 additions & 0 deletions
@@ -641,6 +641,120 @@ def expected():
     assert(analysis.alpha_equal(a, b))
 
 
+def test_alter_layout_pad():
+    """ Check NCHW, NHWC and corner case for pad layout conversion"""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=112)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    # Check NCHW conversion.
+    def before_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+        ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nchw()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nchw()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC')
+        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+        ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check that conversion does not happen when padding along the split axis.
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
+        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
+        ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_alter_layout_pool():
     """ Check NCHW, NHWC pool layout conversion"""
     # Register alter op layout. "level" is used to override the previously registered functions.
@@ -815,5 +929,6 @@ def expected_nhwc():
     test_alter_layout_strided_slice()
     test_alter_layout_depthwise_conv2d()
     test_alter_layout_prelu()
+    test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()
