@@ -641,6 +641,120 @@ def expected():
     assert(analysis.alpha_equal(a, b))
 
 
+def test_alter_layout_pad():
+    """Check NCHW, NHWC and corner case for pad layout conversion"""
+    # Register alter op layout. "level" is used to override the previously registered functions.
+    @register_alter_op_layout("nn.conv2d", level=112)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW16c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
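+    # The callback above forces conv2d into the NCHW16c data layout, so the expected
+    # graphs below contain layout_transform ops around conv2d and run nn.pad in the
+    # transformed layout whenever the pass can propagate it.
+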
+    # Check NCHW conversion.
+    def before_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nchw():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
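+        # NCHW16c is a 5-D layout, so the pad that is propagated past the conv2d
+        # carries an extra trailing (0, 0) pair for the split 16c sub-channel axis.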
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+        ret = relay.layout_transform(ret, "NCHW16c", "NCHW")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nchw()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nchw()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check NHWC conversion.
+    def before_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC')
+        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (0, 0)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected_nhwc():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NHWC", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
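+        # The pad stays in the conv's NCHW16c layout; only after the pad is the
+        # result transformed back to the original NHWC layout.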
+        ret = relay.nn.pad(y, pad_width=((0, 0), (0, 0), (1, 1), (1, 1), (0, 0)))
+        ret = relay.layout_transform(ret, "NCHW16c", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before_nhwc()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected_nhwc()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+    # Check that conversion does not happen when padding along the split axis.
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
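+        # This pad also pads the channel axis (axis 1), the axis that NCHW16c splits.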
+        ret = relay.nn.pad(y, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight1 = relay.var('weight1')
+        y = relay.layout_transform(x, "NCHW", "NCHW16c")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout="NCHW16c")
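+        # Padding the split channel axis cannot be expressed in NCHW16c, so the data
+        # is transformed back to NCHW before the pad instead of after it.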
+        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
+        ret = relay.nn.pad(ret, pad_width=((0, 0), (1, 1), (1, 1), (1, 1)))
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.AlterOpLayout())
+
+    b = expected()
+    b = run_opt_pass(b, transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
 def test_alter_layout_pool():
     """Check NCHW, NHWC pool layout conversion"""
     # Register alter op layout. "level" is used to override the previously registered functions.
@@ -815,5 +929,6 @@ def expected_nhwc():
     test_alter_layout_strided_slice()
     test_alter_layout_depthwise_conv2d()
     test_alter_layout_prelu()
+    test_alter_layout_pad()
     test_alter_layout_pool()
     test_alter_layout_sum()