@@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
686686 Array<IndexExpr> padding;
687687 Array<IndexExpr> dilation;
688688 tvm::String layout;
689+ tvm::String out_layout;
689690 bool ceil_mode;
690691
691692 TVM_DECLARE_ATTRS (MaxPool2DAttrs, " relay.attrs.MaxPool2DAttrs" ) {
@@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
709710 " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
710711 " dimensions respectively. Pooling is applied on the 'H' and"
711712 " 'W' dimensions." );
713+ TVM_ATTR_FIELD (out_layout)
714+ .set_default (" " )
715+ .describe (
716+ " Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
717+ " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
718+ " dimensions respectively. Pooling is applied on the 'H' and"
719+ " 'W' dimensions." );
712720 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
713721 " When true, will use ceil instead of floor to compute the output shape." );
714722 }
@@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
721729 Array<IndexExpr> padding;
722730 Array<IndexExpr> dilation;
723731 tvm::String layout;
732+ tvm::String out_layout;
724733 bool ceil_mode;
725734 bool count_include_pad;
726735
@@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
745754 " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
746755 " dimensions respectively. Pooling is applied on the 'H' and"
747756 " 'W' dimensions." );
757+ TVM_ATTR_FIELD (out_layout)
758+ .set_default (" " )
759+ .describe (
760+ " Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
761+ " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
762+ " dimensions respectively. Pooling is applied on the 'H' and"
763+ " 'W' dimensions." );
748764 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
749765 " When true, will use ceil instead of floor to compute the output shape." );
750766 TVM_ATTR_FIELD (count_include_pad)
@@ -756,20 +772,29 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
756772/* ! \brief Attributes for global pool operator */
757773struct GlobalPool2DAttrs : public tvm ::AttrsNode<GlobalPool2DAttrs> {
758774 tvm::String layout;
775+ tvm::String out_layout;
759776
760777 TVM_DECLARE_ATTRS (GlobalPool2DAttrs, " relay.attrs.GlobalPool2DAttrs" ) {
761778 TVM_ATTR_FIELD (layout).set_default (" NCHW" ).describe (
762779 " Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
763780 " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
764781 " dimensions respectively. Pooling is applied on the 'H' and"
765782 " 'W' dimensions." );
783+ TVM_ATTR_FIELD (out_layout)
784+ .set_default (" " )
785+ .describe (
786+ " Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
787+ " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
788+ " dimensions respectively. Pooling is applied on the 'H' and"
789+ " 'W' dimensions." );
766790 }
767791};
768792
769793/* ! \brief Attributes for 1d adaptive pool operator */
770794struct AdaptivePool1DAttrs : public tvm ::AttrsNode<AdaptivePool1DAttrs> {
771795 Array<IndexExpr> output_size;
772796 std::string layout;
797+ tvm::String out_layout;
773798
774799 TVM_DECLARE_ATTRS (AdaptivePool1DAttrs, " relay.attrs.AdaptivePool1DAttrs" ) {
775800 TVM_ATTR_FIELD (output_size).set_default (Array<IndexExpr>({})).describe (" Output width." );
@@ -778,13 +803,21 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
778803 " 'N', 'C', 'W' stands for batch, channel, and width"
779804 " dimensions respectively. Pooling is applied on the"
780805 " 'W' dimension." );
806+ TVM_ATTR_FIELD (out_layout)
807+ .set_default (" " )
808+ .describe (
809+ " Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
810+ " 'N', 'C', 'W' stands for batch, channel, and width"
811+ " dimensions respectively. Pooling is applied on the"
812+ " 'W' dimension." );
781813 }
782814};
783815
784816/* ! \brief Attributes for 2d adaptive pool operator */
785817struct AdaptivePool2DAttrs : public tvm ::AttrsNode<AdaptivePool2DAttrs> {
786818 Array<IndexExpr> output_size;
787819 std::string layout;
820+ tvm::String out_layout;
788821
789822 TVM_DECLARE_ATTRS (AdaptivePool2DAttrs, " relay.attrs.AdaptivePool2DAttrs" ) {
790823 TVM_ATTR_FIELD (output_size)
@@ -795,13 +828,21 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
795828 " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
796829 " dimensions respectively. Pooling is applied on the 'H' and"
797830 " 'W' dimensions." );
831+ TVM_ATTR_FIELD (out_layout)
832+ .set_default (" " )
833+ .describe (
834+ " Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
835+ " 'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
836+ " dimensions respectively. Pooling is applied on the 'H' and"
837+ " 'W' dimensions." );
798838 }
799839};
800840
801841/* ! \brief Attributes for 3d adaptive pool operator */
802842struct AdaptivePool3DAttrs : public tvm ::AttrsNode<AdaptivePool3DAttrs> {
803843 Array<IndexExpr> output_size;
804844 std::string layout;
845+ tvm::String out_layout;
805846
806847 TVM_DECLARE_ATTRS (AdaptivePool3DAttrs, " relay.attrs.AdaptivePool3DAttrs" ) {
807848 TVM_ATTR_FIELD (output_size)
@@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
812853 " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
813854 " dimensions respectively. Pooling is applied on 'D', 'H' and"
814855 " 'W' dimensions." );
856+ TVM_ATTR_FIELD (out_layout)
857+ .set_default (" " )
858+ .describe (
859+ " Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
860+ " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
861+ " dimensions respectively. Pooling is applied on 'D', 'H' and"
862+ " 'W' dimensions." );
815863 }
816864};
817865
@@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
822870 Array<IndexExpr> dilation;
823871 Array<IndexExpr> padding;
824872 std::string layout;
873+ tvm::String out_layout;
825874 bool ceil_mode;
826875
827876 TVM_DECLARE_ATTRS (MaxPool1DAttrs, " relay.attrs.MaxPool1DAttrs" ) {
@@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
844893 " Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
845894 " 'N', 'C', 'W' stands for batch, channel, and width"
846895 " dimensions respectively. Pooling is applied on the 'W' dimensions." );
896+ TVM_ATTR_FIELD (out_layout)
897+ .set_default (" " )
898+ .describe (
899+ " Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
900+ " 'N', 'C', 'W' stands for batch, channel, and width"
901+ " dimensions respectively. Pooling is applied on the 'W' dimensions." );
847902 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
848903 " When true, will use ceil instead of floor to compute the output shape." );
849904 }
@@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
856911 Array<IndexExpr> dilation;
857912 Array<IndexExpr> padding;
858913 std::string layout;
914+ tvm::String out_layout;
859915 bool ceil_mode;
860916 bool count_include_pad;
861917
@@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
879935 " Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
880936 " 'N', 'C', 'W' stands for batch, channel, and width"
881937 " dimensions respectively. Pooling is applied on the 'W' dimension." );
938+ TVM_ATTR_FIELD (out_layout)
939+ .set_default (" " )
940+ .describe (
941+ " Dimension ordering of output data. Can be 'NCW', 'NHC', etc."
942+ " 'N', 'C', 'W' stands for batch, channel, and width"
943+ " dimensions respectively. Pooling is applied on the 'W' dimension." );
882944 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
883945 " When true, will use ceil instead of floor to compute the output shape." );
884946 TVM_ATTR_FIELD (count_include_pad)
@@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
894956 Array<IndexExpr> dilation;
895957 Array<IndexExpr> padding;
896958 std::string layout;
959+ tvm::String out_layout;
897960 bool ceil_mode;
898961
899962 TVM_DECLARE_ATTRS (MaxPool3DAttrs, " relay.attrs.MaxPool3DAttrs" ) {
@@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
917980 " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
918981 " dimensions respectively. Pooling is applied on the 'D', 'H' and"
919982 " 'W' dimensions." );
983+ TVM_ATTR_FIELD (out_layout)
984+ .set_default (" " )
985+ .describe (
986+ " Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
987+ " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
988+ " dimensions respectively. Pooling is applied on the 'D', 'H' and"
989+ " 'W' dimensions." );
920990 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
921991 " When true, will use ceil instead of floor to compute the output shape." );
922992 }
@@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
929999 Array<IndexExpr> dilation;
9301000 Array<IndexExpr> padding;
9311001 std::string layout;
1002+ tvm::String out_layout;
9321003 bool ceil_mode;
9331004 bool count_include_pad;
9341005
@@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
9531024 " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
9541025 " dimensions respectively. Pooling is applied on the 'D', 'H' and"
9551026 " 'W' dimensions." );
1027+ TVM_ATTR_FIELD (out_layout)
1028+ .set_default (" " )
1029+ .describe (
1030+ " Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
1031+ " 'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
1032+ " dimensions respectively. Pooling is applied on the 'D', 'H' and"
1033+ " 'W' dimensions." );
9561034 TVM_ATTR_FIELD (ceil_mode).set_default (false ).describe (
9571035 " When true, will use ceil instead of floor to compute the output shape." );
9581036 TVM_ATTR_FIELD (count_include_pad)
0 commit comments