Skip to content

Commit 84353c5

Browse files
committed
Register layout conversion function to more reduce ops
1 parent 18a36a7 commit 84353c5

File tree

2 files changed

+61
-38
lines changed

2 files changed

+61
-38
lines changed

src/relay/op/tensor/reduce.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,14 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
116116
}
117117

118118
// Return the modified layout for AlterOpLayout pass.
119+
template <typename T>
119120
InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
120121
const Array<Layout>& new_in_layouts,
121122
const Array<Layout>& old_in_layouts,
122123
const Array<tvm::relay::Type>& old_in_types) {
123-
const auto* attrs_ptr = attrs.as<ReduceAttrs>();
124+
const auto* attrs_ptr = attrs.as<T>();
124125
ICHECK(attrs_ptr);
125-
ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);
126+
ObjectPtr<T> params = make_object<T>(*attrs_ptr);
126127

127128
// Get the reduce axes.
128129
Array<Array<IndexExpr>> old_in_shapes;
@@ -389,6 +390,7 @@ values over a given axis.
389390
.set_support_level(4)
390391
.add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
391392
.set_attr<FTVMCompute>("FTVMCompute", ArgMaxCompute)
393+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
392394
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
393395

394396
Array<te::Tensor> ArgMinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -405,6 +407,7 @@ values over a given axis.
405407
.set_support_level(4)
406408
.add_type_rel("ArgReduce", GenericReduceRel<ArgReduceAttrs>)
407409
.set_attr<FTVMCompute>("FTVMCompute", ArgMinCompute)
410+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ArgReduceAttrs>)
408411
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
409412

410413
Array<te::Tensor> SumCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -433,7 +436,7 @@ Example::
433436
.set_attrs_type<ReduceAttrs>()
434437
.set_support_level(4)
435438
.add_type_rel("Reduce", ReduceRel)
436-
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout)
439+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
437440
.set_attr<FTVMCompute>("FTVMCompute", SumCompute)
438441
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
439442

@@ -468,6 +471,7 @@ Example::
468471
.set_support_level(4)
469472
.add_type_rel("Reduce", ReduceRel)
470473
.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
474+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
471475
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
472476

473477
Array<te::Tensor> AnyCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -516,6 +520,7 @@ RELAY_REGISTER_REDUCE_OP("max")
516520
.set_support_level(4)
517521
.add_type_rel("Reduce", ReduceRel)
518522
.set_attr<FTVMCompute>("FTVMCompute", MaxCompute)
523+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
519524
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
520525

521526
Array<te::Tensor> MinCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -531,6 +536,7 @@ RELAY_REGISTER_REDUCE_OP("min")
531536
.set_support_level(4)
532537
.add_type_rel("Reduce", ReduceRel)
533538
.set_attr<FTVMCompute>("FTVMCompute", MinCompute)
539+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
534540
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
535541

536542
Array<te::Tensor> ProdCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -551,17 +557,18 @@ Example::
551557
[[1,4],[4,3],[5,2]],
552558
[[7,1],[7,2],[7,3]]]
553559
554-
mean(data, axis=1)
560+
prod(data, axis=1)
555561
[35562240]
556562
557-
mean(data, axis=[1,2])
563+
prod(data, axis=[1,2])
558564
[ 36 480 2058]
559565
560566
)code" TVM_ADD_FILELINE)
561567
.set_attrs_type<ReduceAttrs>()
562568
.set_support_level(4)
563569
.add_type_rel("Reduce", ReduceRel)
564570
.set_attr<FTVMCompute>("FTVMCompute", ProdCompute)
571+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
565572
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
566573

567574
Array<te::Tensor> MeanCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
@@ -600,6 +607,7 @@ Example::
600607
.set_support_level(4)
601608
.add_type_rel("Reduce", ReduceRel)
602609
.set_attr<FTVMCompute>("FTVMCompute", MeanCompute)
610+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
603611
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
604612

605613
bool VarianceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
@@ -675,6 +683,7 @@ RELAY_REGISTER_OP("variance")
675683
.add_argument("mean", "Tensor", "The mean tensor.")
676684
.add_type_rel("Variance", VarianceRel)
677685
.set_attr<FTVMCompute>("FTVMCompute", VarianceCompute)
686+
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ReduceInferCorrectLayout<ReduceAttrs>)
678687
.set_attr<TOpPattern>("TOpPattern", kCommReduce);
679688

680689
} // namespace relay

tests/python/relay/test_pass_convert_op_layout.py

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Test alter op layout pass"""
18+
import pytest
19+
1820
import tvm
1921
from tvm import te
2022

@@ -1925,37 +1927,49 @@ def infer_correct_layout_relu(attrs, new_in_layouts, old_in_layouts, old_in_type
19251927
assert test_infer_correct_layout_flag == True
19261928

19271929

1930+
def test_reduce_op_convert_layout():
1931+
for reduce_op in [relay.argmax, relay.mean, relay.max]:
1932+
1933+
def before():
1934+
x = relay.var("x", shape=(1, 64, 56, 56))
1935+
weight = relay.var("weight", shape=(64, 64, 3, 3))
1936+
y = relay.nn.conv2d(
1937+
x,
1938+
weight,
1939+
channels=64,
1940+
kernel_size=(3, 3),
1941+
padding=(1, 1),
1942+
data_layout="NCHW",
1943+
kernel_layout="OIHW",
1944+
)
1945+
y = reduce_op(y, axis=[2, 3])
1946+
y = relay.Function([x, weight], y)
1947+
return y
1948+
1949+
def expected():
1950+
x = relay.var("x", shape=(1, 64, 56, 56))
1951+
weight = relay.var("weight", shape=(64, 64, 3, 3))
1952+
x = relay.layout_transform(x, "NCHW", "NHWC")
1953+
weight = relay.layout_transform(weight, "OIHW", "HWIO")
1954+
y = relay.nn.conv2d(
1955+
x,
1956+
weight,
1957+
channels=64,
1958+
kernel_size=(3, 3),
1959+
padding=(1, 1),
1960+
data_layout="NHWC",
1961+
kernel_layout="HWIO",
1962+
)
1963+
y = reduce_op(y, axis=[1, 2])
1964+
y = relay.Function(relay.analysis.free_vars(y), y)
1965+
return y
1966+
1967+
a = before()
1968+
a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "default"]}))
1969+
b = run_opt_pass(expected(), transform.InferType())
1970+
1971+
assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
1972+
1973+
19281974
if __name__ == "__main__":
1929-
test_qnn_binary_no_convert_layout()
1930-
test_no_convert_layout()
1931-
test_conv_convert_layout()
1932-
test_conv_nhwc_convert_layout()
1933-
test_conv_bias_pool_convert_layout()
1934-
test_conv_concat_convert_layout()
1935-
test_dual_path_convert_layout()
1936-
test_bn_convert_layout()
1937-
test_slice_like_convert_layout()
1938-
test_transpose_convert_layout()
1939-
test_resnet_convert_layout()
1940-
test_scalar_convert_layout()
1941-
test_conv_bn_convert_layout()
1942-
test_qnn_conv_requantize_convert_layout()
1943-
test_qnn_conv_concat_convert_layout()
1944-
test_qnn_conv_add_convert_layout()
1945-
test_qnn_conv_nhwc_convert_layout()
1946-
test_conv_convert_kernel_layout()
1947-
test_conv_transpose_convert_layout()
1948-
test_conv_roi_align_convert_layout()
1949-
test_conv_roi_pool_convert_layout()
1950-
test_conv_strided_slice_convert_layout()
1951-
test_deformable_conv_bias_pool_convert_layout()
1952-
test_default_keyword()
1953-
test_different_ops_convert_layout()
1954-
test_no_desired_layout()
1955-
test_convert_with_config()
1956-
test_conv_squeeze_convert_layout()
1957-
test_conv_reduce_convert_layout()
1958-
test_conv_strided_slice_axes_convert_layout()
1959-
test_image_resize_convert_layout()
1960-
test_conv_image_resize_convert_layout()
1961-
test_infer_correct_layout()
1975+
pytest.main([__file__])

0 commit comments

Comments
 (0)