Skip to content

Commit d11bdcd

Browse files
authored
[Op] Do not override specified layout in pooling (2nd PR) (#9328)
* [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR) * [Op] Do not override specified layout in pooling (2nd PR)
1 parent e62075d commit d11bdcd

File tree

9 files changed

+675
-75
lines changed

9 files changed

+675
-75
lines changed

include/tvm/relay/attrs/nn.h

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
686686
Array<IndexExpr> padding;
687687
Array<IndexExpr> dilation;
688688
tvm::String layout;
689+
tvm::String out_layout;
689690
bool ceil_mode;
690691

691692
TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
@@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
709710
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
710711
"dimensions respectively. Pooling is applied on the 'H' and"
711712
"'W' dimensions.");
713+
TVM_ATTR_FIELD(out_layout)
714+
.set_default("")
715+
.describe(
716+
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
717+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
718+
"dimensions respectively. Pooling is applied on the 'H' and"
719+
"'W' dimensions.");
712720
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
713721
"When true, will use ceil instead of floor to compute the output shape.");
714722
}
@@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
721729
Array<IndexExpr> padding;
722730
Array<IndexExpr> dilation;
723731
tvm::String layout;
732+
tvm::String out_layout;
724733
bool ceil_mode;
725734
bool count_include_pad;
726735

@@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
745754
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
746755
"dimensions respectively. Pooling is applied on the 'H' and"
747756
"'W' dimensions.");
757+
TVM_ATTR_FIELD(out_layout)
758+
.set_default("")
759+
.describe(
760+
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
761+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
762+
"dimensions respectively. Pooling is applied on the 'H' and"
763+
"'W' dimensions.");
748764
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
749765
"When true, will use ceil instead of floor to compute the output shape.");
750766
TVM_ATTR_FIELD(count_include_pad)
@@ -756,20 +772,29 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
756772
/*! \brief Attributes for global pool operator */
757773
struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
758774
tvm::String layout;
775+
tvm::String out_layout;
759776

760777
TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
761778
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
762779
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
763780
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
764781
"dimensions respectively. Pooling is applied on the 'H' and"
765782
"'W' dimensions.");
783+
TVM_ATTR_FIELD(out_layout)
784+
.set_default("")
785+
.describe(
786+
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
787+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
788+
"dimensions respectively. Pooling is applied on the 'H' and"
789+
"'W' dimensions.");
766790
}
767791
};
768792

769793
/*! \brief Attributes for 1d adaptive pool operator */
770794
struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
771795
Array<IndexExpr> output_size;
772796
std::string layout;
797+
tvm::String out_layout;
773798

774799
TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") {
775800
TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({})).describe("Output width.");
@@ -778,13 +803,21 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
778803
"'N', 'C', 'W' stands for batch, channel, and width"
779804
"dimensions respectively. Pooling is applied on the"
780805
"'W' dimension.");
806+
TVM_ATTR_FIELD(out_layout)
807+
.set_default("")
808+
.describe(
809+
"Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
810+
"'N', 'C', 'W' stands for batch, channel, and width"
811+
"dimensions respectively. Pooling is applied on the"
812+
"'W' dimension.");
781813
}
782814
};
783815

784816
/*! \brief Attributes for 2d adaptive pool operator */
785817
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
786818
Array<IndexExpr> output_size;
787819
std::string layout;
820+
tvm::String out_layout;
788821

789822
TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
790823
TVM_ATTR_FIELD(output_size)
@@ -795,13 +828,21 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
795828
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
796829
"dimensions respectively. Pooling is applied on the 'H' and"
797830
"'W' dimensions.");
831+
TVM_ATTR_FIELD(out_layout)
832+
.set_default("")
833+
.describe(
834+
"Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
835+
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
836+
"dimensions respectively. Pooling is applied on the 'H' and"
837+
"'W' dimensions.");
798838
}
799839
};
800840

801841
/*! \brief Attributes for 3d adaptive pool operator */
802842
struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
803843
Array<IndexExpr> output_size;
804844
std::string layout;
845+
tvm::String out_layout;
805846

806847
TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") {
807848
TVM_ATTR_FIELD(output_size)
@@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
812853
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
813854
"dimensions respectively. Pooling is applied on 'D', 'H' and"
814855
"'W' dimensions.");
856+
TVM_ATTR_FIELD(out_layout)
857+
.set_default("")
858+
.describe(
859+
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
860+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
861+
"dimensions respectively. Pooling is applied on 'D', 'H' and"
862+
"'W' dimensions.");
815863
}
816864
};
817865

@@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
822870
Array<IndexExpr> dilation;
823871
Array<IndexExpr> padding;
824872
std::string layout;
873+
tvm::String out_layout;
825874
bool ceil_mode;
826875

827876
TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") {
@@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
844893
"Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
845894
"'N', 'C', 'W' stands for batch, channel, and width"
846895
"dimensions respectively. Pooling is applied on the 'W' dimensions.");
896+
TVM_ATTR_FIELD(out_layout)
897+
.set_default("")
898+
.describe(
899+
"Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
900+
"'N', 'C', 'W' stands for batch, channel, and width"
901+
"dimensions respectively. Pooling is applied on the 'W' dimensions.");
847902
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
848903
"When true, will use ceil instead of floor to compute the output shape.");
849904
}
@@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
856911
Array<IndexExpr> dilation;
857912
Array<IndexExpr> padding;
858913
std::string layout;
914+
tvm::String out_layout;
859915
bool ceil_mode;
860916
bool count_include_pad;
861917

@@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
879935
"Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
880936
"'N', 'C', 'W' stands for batch, channel, and width"
881937
"dimensions respectively. Pooling is applied on the 'W' dimension.");
938+
TVM_ATTR_FIELD(out_layout)
939+
.set_default("")
940+
.describe(
941+
"Dimension ordering of output data. Can be 'NCW', 'NHC', etc."
942+
"'N', 'C', 'W' stands for batch, channel, and width"
943+
"dimensions respectively. Pooling is applied on the 'W' dimension.");
882944
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
883945
"When true, will use ceil instead of floor to compute the output shape.");
884946
TVM_ATTR_FIELD(count_include_pad)
@@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
894956
Array<IndexExpr> dilation;
895957
Array<IndexExpr> padding;
896958
std::string layout;
959+
tvm::String out_layout;
897960
bool ceil_mode;
898961

899962
TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") {
@@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
917980
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
918981
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
919982
"'W' dimensions.");
983+
TVM_ATTR_FIELD(out_layout)
984+
.set_default("")
985+
.describe(
986+
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
987+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
988+
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
989+
"'W' dimensions.");
920990
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
921991
"When true, will use ceil instead of floor to compute the output shape.");
922992
}
@@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
929999
Array<IndexExpr> dilation;
9301000
Array<IndexExpr> padding;
9311001
std::string layout;
1002+
tvm::String out_layout;
9321003
bool ceil_mode;
9331004
bool count_include_pad;
9341005

@@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
9531024
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
9541025
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
9551026
"'W' dimensions.");
1027+
TVM_ATTR_FIELD(out_layout)
1028+
.set_default("")
1029+
.describe(
1030+
"Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
1031+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
1032+
"dimensions respectively. Pooling is applied on the 'D', 'H' and"
1033+
"'W' dimensions.");
9561034
TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
9571035
"When true, will use ceil instead of floor to compute the output shape.");
9581036
TVM_ATTR_FIELD(count_include_pad)

python/tvm/relay/op/nn/_nn.py

Lines changed: 97 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""Backend compiler related feature registration"""
1919
from __future__ import absolute_import
2020

21-
from tvm import topi
21+
from tvm import topi, relay
2222
from tvm.topi.utils import get_const_tuple
2323

2424
from tvm.runtime import convert
@@ -267,9 +267,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
267267
result : tvm.relay.Expr
268268
The transformed expr
269269
"""
270-
# pylint: disable=import-outside-toplevel
271-
from tvm import relay
272-
273270
data, weight = inputs
274271

275272
# First check if there is a LayoutConfig scope, and if so, whether
@@ -363,9 +360,6 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
363360
result : tvm.relay.Expr
364361
The transformed expr
365362
"""
366-
# pylint: disable=import-outside-toplevel
367-
from tvm import relay
368-
369363
data, weight = inputs
370364
new_attrs = dict(attrs)
371365
assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
@@ -446,9 +440,6 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
446440
result : tvm.relay.Expr
447441
The transformed expr
448442
"""
449-
# pylint: disable=import-outside-toplevel
450-
from tvm import relay
451-
452443
data, weight = inputs
453444
new_attrs = dict(attrs)
454445
assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs"
@@ -515,6 +506,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
515506
reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
516507

517508

509+
@reg.register_convert_op_layout("nn.max_pool2d")
510+
def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts):
511+
"""Convert Layout pass registration for max_pool2d op.
512+
Parameters
513+
----------
514+
attrs : tvm.ir.Attrs
515+
Attributes of current pooling
516+
inputs : list of tvm.relay.Expr
517+
The args of the Relay expr to be legalized
518+
tinfos : list of types
519+
List of input and output types
520+
desired_layouts : list of one layout string
521+
layout string defining our desired layout for input and output.
522+
Returns
523+
-------
524+
result : tvm.relay.Expr
525+
The transformed expr
526+
"""
527+
new_attrs = dict(attrs)
528+
new_attrs["layout"] = str(desired_layouts[0])
529+
new_attrs["out_layout"] = str(desired_layouts[0])
530+
return relay.nn.max_pool2d(*inputs, **new_attrs)
531+
532+
518533
# max_pool3d
519534
reg.register_schedule("nn.max_pool3d", strategy.schedule_pool)
520535
reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -530,6 +545,30 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
530545
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
531546

532547

548+
@reg.register_convert_op_layout("nn.avg_pool2d")
549+
def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
550+
"""Convert Layout pass registration for avg_pool2d op.
551+
Parameters
552+
----------
553+
attrs : tvm.ir.Attrs
554+
Attributes of current pooling
555+
inputs : list of tvm.relay.Expr
556+
The args of the Relay expr to be legalized
557+
tinfos : list of types
558+
List of input and output types
559+
desired_layouts : list of one layout string
560+
layout string defining our desired layout for input and output.
561+
Returns
562+
-------
563+
result : tvm.relay.Expr
564+
The transformed expr
565+
"""
566+
new_attrs = dict(attrs)
567+
new_attrs["layout"] = str(desired_layouts[0])
568+
new_attrs["out_layout"] = str(desired_layouts[0])
569+
return relay.nn.avg_pool2d(*inputs, **new_attrs)
570+
571+
533572
# avg_pool3d
534573
reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool)
535574
reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -560,11 +599,59 @@ def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
560599
reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
561600

562601

602+
@reg.register_convert_op_layout("nn.global_max_pool2d")
603+
def convert_global_max_pool2d(attrs, inputs, tinfos, desired_layouts):
604+
"""Convert Layout pass registration for global_max_pool2d op.
605+
Parameters
606+
----------
607+
attrs : tvm.ir.Attrs
608+
Attributes of current pooling
609+
inputs : list of tvm.relay.Expr
610+
The args of the Relay expr to be legalized
611+
tinfos : list of types
612+
List of input and output types
613+
desired_layouts : list of one layout string
614+
layout string defining our desired layout for input and output.
615+
Returns
616+
-------
617+
result : tvm.relay.Expr
618+
The transformed expr
619+
"""
620+
new_attrs = dict(attrs)
621+
new_attrs["layout"] = str(desired_layouts[0])
622+
new_attrs["out_layout"] = str(desired_layouts[0])
623+
return relay.nn.global_max_pool2d(*inputs, **new_attrs)
624+
625+
563626
# global_avg_pool2d
564627
reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool)
565628
reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
566629

567630

631+
@reg.register_convert_op_layout("nn.global_avg_pool2d")
632+
def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
633+
"""Convert Layout pass registration for global_avg_pool2d op.
634+
Parameters
635+
----------
636+
attrs : tvm.ir.Attrs
637+
Attributes of current pooling
638+
inputs : list of tvm.relay.Expr
639+
The args of the Relay expr to be legalized
640+
tinfos : list of types
641+
List of input and output types
642+
desired_layouts : list of one layout string
643+
layout string defining our desired layout for input and output.
644+
Returns
645+
-------
646+
result : tvm.relay.Expr
647+
The transformed expr
648+
"""
649+
new_attrs = dict(attrs)
650+
new_attrs["layout"] = str(desired_layouts[0])
651+
new_attrs["out_layout"] = str(desired_layouts[0])
652+
return relay.nn.global_avg_pool2d(*inputs, **new_attrs)
653+
654+
568655
# adaptive_max_pool2d
569656
reg.register_schedule("nn.adaptive_max_pool2d", strategy.schedule_adaptive_pool)
570657
reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -796,9 +883,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts):
796883
result : tvm.relay.Expr
797884
The transformed expr
798885
"""
799-
# pylint: disable=import-outside-toplevel
800-
from tvm import relay
801-
802886
data, offset, weight = inputs
803887
new_attrs = dict(attrs)
804888
for attr in new_attrs:

0 commit comments

Comments
 (0)