src/relay/op/nn/pad.cc: 77 additions & 0 deletions
@@ -36,6 +36,82 @@ namespace relay {
// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);

Array<Array<Layout> > PadInferCorrectLayout(
    const Attrs& attrs,
    const Array<Layout>& new_in_layouts,
    const Array<Layout>& old_in_layouts,
    const Array<Array<IndexExpr>>& old_in_shapes) {
  // NOTE: Discard "const" qualifier here.
  PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());

  Layout ret;
  // If new_in_layouts is defined, this code tries to modify the layout.
  bool is_layout_modified = new_in_layouts.defined();
  if (new_in_layouts.defined()) {
    // Create a map from axis to pad_width using the old layout, then generate a new
    // pad_width for the new layout from that map. The new layout is rejected if padding
    // occurs along an axis that was split. (A Python sketch of this remapping follows
    // this file's diff.)

    // 1) Create a map from axis to pad_width using the old layout.
    std::map<std::string, tvm::Array<tvm::Expr>> axis_pad_width;
    int index_counter = 0;
    CHECK_EQ(new_in_layouts.size(), 1);
    CHECK_EQ(old_in_layouts.size(), 1);
    for (auto iter_var : old_in_layouts[0]->axes) {
      const auto& old_layout_axis = LayoutAxis::Get(iter_var);
      axis_pad_width.emplace(old_layout_axis.name(), params->pad_width[index_counter]);
      index_counter++;
    }

    // 2) Create the new pad_width by walking over the new layout and consulting the map.
    tvm::Array<tvm::Array<tvm::Expr>> new_pad_width;
    for (auto iter_var : new_in_layouts[0]->axes) {
      const auto& new_layout_axis = LayoutAxis::Get(iter_var);
      auto axis_name = new_layout_axis.name();
      if (axis_pad_width.count(axis_name) != 0 && new_layout_axis.IsPrimal()) {
        // This is a primal axis, so directly reuse the original pad_width.
        new_pad_width.push_back(axis_pad_width.at(axis_name));
      } else {
        // This axis was split, so its original pad_width must have been [0, 0].
        const auto& dual_axis = new_layout_axis.ToPrimal();
        auto dual_axis_name = dual_axis.name();
        CHECK(axis_pad_width.count(dual_axis_name))
            << "Missing axis " << dual_axis << " in " << old_in_layouts[0].name();
        new_pad_width.push_back(axis_pad_width.at(dual_axis_name));

        // If any pad_width element is not zero, do not change the layout.
        for (auto width : axis_pad_width.at(dual_axis_name)) {
          if (auto* width_imm = width.as<IntImm>()) {
            if (width_imm->value != 0) {
              is_layout_modified = false;
            }
          } else {
            is_layout_modified = false;
          }
        }
      }
    }

    // If the above conditions are satisfied, set the newly created pad_width and use
    // the new layout.
    if (is_layout_modified) {
      ret = new_in_layouts[0];
      params->pad_width = new_pad_width;
    }
  }

  if (!is_layout_modified) {
    if (old_in_layouts.defined()) {
      CHECK_EQ(old_in_layouts.size(), 1);
      ret = old_in_layouts[0];
    } else {
      ret = Layout::Undef();
    }
  }

  return Array<Array<Layout> >{{ret}, {ret}};
}

bool PadRel(const Array<Type>& types,
            int num_inputs,
            const Attrs& attrs,
@@ -133,6 +209,7 @@ RELAY_REGISTER_OP("nn.pad")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("Pad", PadRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", PadInferCorrectLayout)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
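
To make the remapping rule above concrete, here is a minimal Python sketch of the same logic, assuming single-letter primal axes (as in "NCHW") whose split sub-axes are written as a factor followed by the lowercase letter (the "16c" in "NCHW16c"). The helper remap_pad_width is hypothetical, not a TVM API.

# A minimal sketch of the pad_width remapping rule, assuming single-letter
# primal axes ("NCHW") and lowercase split sub-axes (the "c" in "NCHW16c").
# remap_pad_width is a hypothetical helper, not a TVM API.
def remap_pad_width(old_layout, new_layout, pad_width):
    """Return pad_width for new_layout, or None when the new layout must be
    rejected because padding happens along a split axis."""
    # 1) Map each primal axis of the old layout to its (before, after) pad pair.
    axis_pad = dict(zip(old_layout, pad_width))

    new_pad_width = []
    for ax in new_layout:
        if ax.isdigit():
            continue                      # skip split factors such as "16"
        if ax.isupper():
            # Primal axis: reuse the original pad pair.
            new_pad_width.append(axis_pad[ax])
        else:
            # Split sub-axis: only legal if the primal axis is unpadded.
            if axis_pad[ax.upper()] != (0, 0):
                return None
            new_pad_width.append((0, 0))
    return new_pad_width

# Padding H and W survives NCHW -> NCHW16c; the inner 16c axis gets (0, 0).
assert remap_pad_width("NCHW", "NCHW16c",
                       ((0, 0), (0, 0), (1, 1), (1, 1))) \
    == [(0, 0), (0, 0), (1, 1), (1, 1), (0, 0)]
# NHWC -> NCHW16c reorders the pairs to match the new axis order.
assert remap_pad_width("NHWC", "NCHW16c",
                       ((0, 0), (1, 1), (1, 1), (0, 0))) \
    == [(0, 0), (0, 0), (1, 1), (1, 1), (0, 0)]
# Padding the channel axis, which is being split, rejects the layout change.
assert remap_pad_width("NCHW", "NCHW16c",
                       ((0, 0), (1, 1), (1, 1), (1, 1))) is None

The sketch mirrors the two-step structure of PadInferCorrectLayout: build the axis-to-pad map from the old layout, then walk the new layout, copying pairs for primal axes and rejecting the change whenever a split axis carries nonzero padding.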

tests/python/frontend/coreml/test_forward.py: 3 additions & 1 deletion
@@ -47,8 +47,10 @@ def run_model_checkonly(model_file, model_name='', input_name='image'):
    model = cm.models.MLModel(model_file)
    x = model_zoo.get_cat_image()
    shape_dict = {input_name : x.shape}
    # Some Relay passes change operators in place, so generate a fresh graph
    # for each target.
    for target, ctx in ctx_list():
        mod, params = relay.frontend.from_coreml(model, shape_dict)
        tvm_output = get_tvm_output(mod["main"], x, params, target, ctx)
        print(target, ctx, model_name, 'prediction id: ', np.argmax(tvm_output.flat))
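
The rebuild inside the loop matters because passes such as AlterOpLayout can rewrite operator attributes in place; PadInferCorrectLayout above, for example, const_casts PadAttrs and overwrites pad_width. A contrived sketch of the hazard, using a stand-in class rather than the real TVM objects:

# Contrived sketch of the in-place mutation hazard; PadAttrs here is a
# stand-in class, not the TVM object.
class PadAttrs:
    def __init__(self, pad_width):
        self.pad_width = pad_width

def alter_layout(attrs):
    # Mirrors the const_cast in pad.cc: the pass rewrites the shared attrs.
    attrs.pad_width = attrs.pad_width + [(0, 0)]  # pad pair for the new 16c axis

attrs = PadAttrs([(0, 0), (0, 0), (1, 1), (1, 1)])
alter_layout(attrs)   # first target: 4 pairs become 5, as intended
alter_layout(attrs)   # reusing the same graph: 5 become 6, which is wrong
assert len(attrs.pad_width) == 6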

tests/python/relay/test_pass_alter_op_layout.py: 115 additions & 0 deletions
@@ -641,6 +641,120 @@ def expected():
    assert(analysis.alpha_equal(a, b))


def test_alter_layout_pad():
    """Check NCHW, NHWC, and a corner case for pad layout conversion."""
    # Register alter op layout. "level" is used to override the previously registered functions.
    @register_alter_op_layout("nn.conv2d", level=112)
    def alter_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW16c'
        return relay.nn.conv2d(data, weight, **new_attrs)

    # Check NCHW conversion.
    def before_nchw():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected_nchw():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
        ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before_nchw()
    a = run_opt_pass(a, transform.AlterOpLayout())

    b = expected_nchw()
    b = run_opt_pass(b, transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

    # Check NHWC conversion.
    def before_nhwc():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout='NHWC')
        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected_nhwc():
        x = relay.var("x", shape=(1, 56, 56, 64))
        weight1 = relay.var('weight1')
        y = relay.layout_transform(x, "NHWC", "NCHW16c")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
        ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before_nhwc()
    a = run_opt_pass(a, transform.AlterOpLayout())

    b = expected_nhwc()
    b = run_opt_pass(b, transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)

    # Check that conversion does not happen when padding along the split axis.
    # Since pad_width touches the channel axis that NCHW16c splits, the pad stays
    # in NCHW and the layout_transform is inserted before it.
    def before():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        y = relay.nn.conv2d(x, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1))
        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    def expected():
        x = relay.var("x", shape=(1, 64, 56, 56))
        weight1 = relay.var('weight1')
        y = relay.layout_transform(x, "NCHW", "NCHW16c")
        y = relay.nn.conv2d(y, weight1,
                            channels=32,
                            kernel_size=(3, 3),
                            padding=(1, 1),
                            data_layout="NCHW16c")
        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
        ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
        y = relay.Function(analysis.free_vars(ret), ret)
        return y

    a = before()
    a = run_opt_pass(a, transform.AlterOpLayout())

    b = expected()
    b = run_opt_pass(b, transform.InferType())

    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)


def test_alter_layout_pool():
    """ Check NCHW, NHWC pool layout conversion"""
    # Register alter op layout. "level" is used to override the previously registered functions.
@@ -815,5 +929,6 @@ def expected_nhwc():
    test_alter_layout_strided_slice()
    test_alter_layout_depthwise_conv2d()
    test_alter_layout_prelu()
    test_alter_layout_pad()
    test_alter_layout_pool()
    test_alter_layout_sum()