@@ -2812,6 +2812,113 @@ TEST(MeanAll, Ctor) {
2812
2812
check_dim_mapping (backward_info.first [1 ], {});
2813
2813
check_dim_mapping (backward_info.second [0 ], {-1 , -1 });
2814
2814
}
2815
+ TEST (BatchNorm, Ctor) {
2816
+ std::vector<int64_t > mesh_shape = {2 , 2 };
2817
+ std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 };
2818
+ std::vector<std::string> dim_names = {" x" , " y" };
2819
+ ProcessMesh process_mesh (mesh_shape, process_ids, dim_names);
2820
+
2821
+ // test forward
2822
+ // data_format = NCHW
2823
+ // [0, 1, -1, -1],[-1],[-1],[-1],[-1] ->[-1 , 1, -1, -1],[1],[1],[1],[1],[-1]
2824
+ auto x_dist_attr = TensorDistAttr ();
2825
+ x_dist_attr.set_process_mesh (process_mesh);
2826
+ x_dist_attr.set_dims_mapping ({0 , 1 , -1 , -1 });
2827
+ x_dist_attr.set_dynamic_dims ({false , false , false , false });
2828
+ auto one_dim_dist_attr = TensorDistAttr ();
2829
+ one_dim_dist_attr.set_process_mesh (process_mesh);
2830
+ one_dim_dist_attr.set_dims_mapping ({-1 });
2831
+ one_dim_dist_attr.set_dynamic_dims ({false });
2832
+
2833
+ phi::distributed::DistMetaTensor x = phi::distributed::DistMetaTensor (
2834
+ common::make_ddim ({16 , 16 , 16 , 16 }), x_dist_attr);
2835
+ phi::distributed::DistMetaTensor mean = phi::distributed::DistMetaTensor (
2836
+ common::make_ddim ({16 }), one_dim_dist_attr);
2837
+ phi::distributed::DistMetaTensor variance = phi::distributed::DistMetaTensor (
2838
+ common::make_ddim ({16 }), one_dim_dist_attr);
2839
+ phi::distributed::DistMetaTensor scale = phi::distributed::DistMetaTensor (
2840
+ common::make_ddim ({16 }), one_dim_dist_attr);
2841
+ phi::distributed::DistMetaTensor bias = phi::distributed::DistMetaTensor (
2842
+ common::make_ddim ({16 }), one_dim_dist_attr);
2843
+ phi::distributed::SpmdInfo forward_info =
2844
+ phi::distributed::BatchNormInferSpmdStatic (
2845
+ x, mean, variance, scale, bias);
2846
+
2847
+ EXPECT_EQ (forward_info.first .size (), 5UL );
2848
+ EXPECT_EQ (forward_info.second .size (), 6UL );
2849
+ check_dim_mapping (forward_info.first [0 ], {-1 , 1 , -1 , -1 });
2850
+ check_dim_mapping (forward_info.first [1 ], {1 });
2851
+ check_dim_mapping (forward_info.first [2 ], {1 });
2852
+ check_dim_mapping (forward_info.first [3 ], {-1 });
2853
+ check_dim_mapping (forward_info.first [4 ], {-1 });
2854
+ check_dim_mapping (forward_info.second [0 ], {-1 , 1 , -1 , -1 });
2855
+ check_dim_mapping (forward_info.second [1 ], {1 });
2856
+ check_dim_mapping (forward_info.second [2 ], {1 });
2857
+ check_dim_mapping (forward_info.second [3 ], {1 });
2858
+ check_dim_mapping (forward_info.second [4 ], {1 });
2859
+ check_dim_mapping (forward_info.second [5 ], {-1 });
2860
+
2861
+ // test backward
2862
+ // data_format = NCHW
2863
+ // [0, 1, -1, -1],[-1],[-1],[-1],[-1],[-1],[-1],[-1],[0, 1, -1, -1]
2864
+ // ->[-1,1,-1,-1],[-1],[-1]
2865
+ // dst_input: [-1, 1, -1, -1],[-1],[-1],[1],[1],[1],[1],[-1],[-1, 1, -1, -1]
2866
+
2867
+ x = phi::distributed::DistMetaTensor (common::make_ddim ({16 , 16 , 16 , 16 }),
2868
+ x_dist_attr);
2869
+ phi::distributed::DistMetaTensor out_grad = phi::distributed::DistMetaTensor (
2870
+ common::make_ddim ({16 , 16 , 16 , 16 }), x_dist_attr);
2871
+ phi::distributed::DistMetaTensor mean_out = phi::distributed::DistMetaTensor (
2872
+ common::make_ddim ({16 }), one_dim_dist_attr);
2873
+ phi::distributed::DistMetaTensor variance_out =
2874
+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2875
+ one_dim_dist_attr);
2876
+ scale = phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2877
+ one_dim_dist_attr);
2878
+ bias = phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2879
+ one_dim_dist_attr);
2880
+ phi::distributed::DistMetaTensor saved_mean =
2881
+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2882
+ one_dim_dist_attr);
2883
+ phi::distributed::DistMetaTensor saved_variance =
2884
+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2885
+ one_dim_dist_attr);
2886
+ phi::distributed::DistMetaTensor reserve_space =
2887
+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2888
+ one_dim_dist_attr);
2889
+ phi::distributed::SpmdInfo backward_info =
2890
+ phi::distributed::BatchNormGradInferSpmd (x,
2891
+ scale,
2892
+ bias,
2893
+ mean_out,
2894
+ variance_out,
2895
+ saved_mean,
2896
+ saved_variance,
2897
+ reserve_space,
2898
+ out_grad,
2899
+ 0.9 ,
2900
+ 0.1 ,
2901
+ " NCHW" ,
2902
+ false ,
2903
+ false ,
2904
+ false );
2905
+
2906
+ EXPECT_EQ (backward_info.first .size (), 9UL );
2907
+ EXPECT_EQ (backward_info.second .size (), 3UL );
2908
+ check_dim_mapping (backward_info.first [0 ], {-1 , 1 , -1 , -1 });
2909
+ check_dim_mapping (backward_info.first [1 ], {-1 });
2910
+ check_dim_mapping (backward_info.first [2 ], {-1 });
2911
+ check_dim_mapping (backward_info.first [3 ], {1 });
2912
+ check_dim_mapping (backward_info.first [4 ], {1 });
2913
+ check_dim_mapping (backward_info.first [5 ], {1 });
2914
+ check_dim_mapping (backward_info.first [6 ], {1 });
2915
+ check_dim_mapping (backward_info.first [7 ], {-1 });
2916
+ check_dim_mapping (backward_info.first [8 ], {-1 , 1 , -1 , -1 });
2917
+
2918
+ check_dim_mapping (backward_info.second [0 ], {-1 , 1 , -1 , -1 });
2919
+ check_dim_mapping (backward_info.second [1 ], {-1 });
2920
+ check_dim_mapping (backward_info.second [2 ], {-1 });
2921
+ }
2815
2922
TEST (Topk, Ctor) {
2816
2923
std::vector<int64_t > mesh_shape = {2 , 2 };
2817
2924
std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 };
0 commit comments