@@ -995,6 +995,38 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
995995 # _verify_swap_axis((4, 5), (5, 4), 0, 0)
996996
997997
998+ def test_forward_depth_to_space ():
999+ def verify (shape , blocksize = 2 ):
1000+ x = np .random .uniform (size = shape ).astype ("float32" )
1001+ ref_res = mx .nd .depth_to_space (mx .nd .array (x ), blocksize )
1002+ mx_sym = mx .sym .depth_to_space (mx .sym .var ("x" ), blocksize )
1003+ shape_dict = {"x" : x .shape , }
1004+ mod , _ = relay .frontend .from_mxnet (mx_sym , shape_dict )
1005+ for target , ctx in ctx_list ():
1006+ for kind in ["graph" , "debug" ]:
1007+ intrp = relay .create_executor (kind , mod = mod , ctx = ctx , target = target )
1008+ op_res = intrp .evaluate ()(x )
1009+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy (), rtol = 1e-3 , atol = 1e-5 )
1010+
1011+ verify ((1 , 18 , 3 , 3 ), 3 )
1012+
1013+
1014+ def test_forward_space_to_depth ():
1015+ def verify (shape , blocksize = 2 ):
1016+ x = np .random .uniform (size = shape ).astype ("float32" )
1017+ ref_res = mx .nd .space_to_depth (mx .nd .array (x ), blocksize )
1018+ mx_sym = mx .sym .space_to_depth (mx .sym .var ("x" ), blocksize )
1019+ shape_dict = {"x" : x .shape , }
1020+ mod , _ = relay .frontend .from_mxnet (mx_sym , shape_dict )
1021+ for target , ctx in ctx_list ():
1022+ for kind in ["graph" , "debug" ]:
1023+ intrp = relay .create_executor (kind , mod = mod , ctx = ctx , target = target )
1024+ op_res = intrp .evaluate ()(x )
1025+ tvm .testing .assert_allclose (op_res .asnumpy (), ref_res .asnumpy (), rtol = 1e-3 , atol = 1e-5 )
1026+
1027+ verify ((1 , 1 , 9 , 9 ), 3 )
1028+
1029+
9981030if __name__ == '__main__' :
9991031 test_forward_mlp ()
10001032 test_forward_vgg ()
@@ -1047,6 +1079,8 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2):
10471079 test_forward_instance_norm ()
10481080 test_forward_layer_norm ()
10491081 test_forward_one_hot ()
1082+ test_forward_depth_to_space ()
1083+ test_forward_space_to_depth ()
10501084 test_forward_convolution ()
10511085 test_forward_deconvolution ()
10521086 test_forward_cond ()
0 commit comments