Skip to content

Commit 4aee08b

Browse files
authored
[Auto Parallel] Add spmd rule No.4、13 for (batch_norm,sync_batch_norm) and their backward ops. (#72918)
* add unary ops which have spmd_rule but not add in yaml file. * Add spmd_rule for batch_norm ops. * Add spmd_rule for batch_norm and batch_norm_grad. * fix bug * fix bug. * fix bug. * add spmd_rule for sync_batch_norm * add spmd_rule for sync_batch_norm * fix bug. * fix bug. * Add partial status. * fix ci bug. * fix CI bug. * apply review. * fix bug.
1 parent a8dcab8 commit 4aee08b

File tree

9 files changed

+599
-1
lines changed

9 files changed

+599
-1
lines changed

paddle/phi/infermeta/spmd_rules/batch_norm.cc

Lines changed: 427 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include "paddle/phi/common/scalar.h"
18+
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
19+
#include "paddle/phi/core/distributed/type_defs.h"
20+
21+
namespace phi {
22+
namespace distributed {
23+
SpmdInfo BatchNormInferSpmd(const DistMetaTensor& x,
24+
const DistMetaTensor& mean,
25+
const DistMetaTensor& variance,
26+
const DistMetaTensor& scale,
27+
const DistMetaTensor& bias,
28+
const bool is_test = false,
29+
const float momentum = 0.9,
30+
const float epsilon = 1e-05,
31+
const std::string& data_format = "NCHW",
32+
const bool use_global_stats = false,
33+
const bool trainable_statistics = false);
34+
SpmdInfo BatchNormInferSpmdStatic(const DistMetaTensor& x,
35+
const DistMetaTensor& mean,
36+
const DistMetaTensor& variance,
37+
const DistMetaTensor& scale,
38+
const DistMetaTensor& bias);
39+
40+
SpmdInfo BatchNormGradInferSpmd(const DistMetaTensor& x,
41+
const DistMetaTensor& scale,
42+
const DistMetaTensor& bias,
43+
const DistMetaTensor& mean_out,
44+
const DistMetaTensor& variance_out,
45+
const DistMetaTensor& saved_mean,
46+
const DistMetaTensor& saved_variance,
47+
const DistMetaTensor& reserve_space,
48+
const DistMetaTensor& out_grad,
49+
const float momentum = 0.9,
50+
const float epsilon = 1e-05,
51+
const std::string& data_format = "NCHW",
52+
const bool is_test = false,
53+
const bool use_global_stats = false,
54+
const bool trainable_statistics = false);
55+
56+
} // namespace distributed
57+
} // namespace phi

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ PD_REGISTER_SPMD_RULE(
6666
fused_rotary_position_embedding,
6767
PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmd),
6868
PD_INFER_SPMD(phi::distributed::FusedRopeInferSpmdReverse));
69-
7069
// replicated rule /* for unittest */
7170
PD_REGISTER_SPMD_RULE(
7271
replicated,
@@ -525,6 +524,9 @@ PD_REGISTER_SPMD_RULE(
525524
PD_REGISTER_SPMD_RULE(mean_all,
526525
PD_INFER_SPMD(phi::distributed::MeanAllInferSpmd),
527526
PD_INFER_SPMD(phi::distributed::MeanAllGradInferSpmd));
527+
// batch_norm
528+
PD_REGISTER_SPMD_RULE(
529+
batch_norm, PD_INFER_SPMD(phi::distributed::BatchNormInferSpmdStatic));
528530

529531
// layer_norm
530532
PD_REGISTER_SPMD_RULE(

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License. */
1919
#include "paddle/phi/infermeta/spmd_rules/argmax.h"
2020
#include "paddle/phi/infermeta/spmd_rules/argmin.h"
2121
#include "paddle/phi/infermeta/spmd_rules/argsort.h"
22+
#include "paddle/phi/infermeta/spmd_rules/batch_norm.h"
2223
#include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
2324
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
2425
#include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"

paddle/phi/ops/yaml/inconsistent/dygraph_backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
infer_meta :
7777
func : GeneralTernaryGradInferMeta
7878
param : [x, scale, bias]
79+
spmd_rule : BatchNormGradInferSpmd
7980
kernel :
8081
func : batch_norm_grad
8182
data_type : out_grad

paddle/phi/ops/yaml/inconsistent/dygraph_ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
5151
infer_meta:
5252
func : BatchNormInferMeta
53+
spmd_rule : BatchNormInferSpmd
5354
kernel :
5455
func : batch_norm
5556
data_type : x

paddle/phi/ops/yaml/inconsistent/static_backward.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
infer_meta :
8989
func : GeneralTernaryGradInferMeta
9090
param : [x, scale, bias]
91+
spmd_rule : BatchNormGradInferSpmd
9192
kernel :
9293
func : batch_norm_grad
9394
data_type : out_grad

paddle/phi/ops/yaml/inconsistent/static_ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
7272
infer_meta:
7373
func : BatchNormInferMeta
74+
spmd_rule : BatchNormInferSpmd
7475
kernel :
7576
func : batch_norm
7677
data_type : x

test/cpp/auto_parallel/spmd_rule_test.cc

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2812,6 +2812,113 @@ TEST(MeanAll, Ctor) {
28122812
check_dim_mapping(backward_info.first[1], {});
28132813
check_dim_mapping(backward_info.second[0], {-1, -1});
28142814
}
2815+
TEST(BatchNorm, Ctor) {
  // 2x2 mesh over 4 ranks; axes named "x" and "y".
  const std::vector<int64_t> mesh_shape = {2, 2};
  const std::vector<int64_t> process_ids = {0, 1, 2, 3};
  const std::vector<std::string> dim_names = {"x", "y"};
  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);

  // Dist attr for the 4-D activation: sharded on mesh axes 0 and 1.
  auto x_dist_attr = TensorDistAttr();
  x_dist_attr.set_process_mesh(process_mesh);
  x_dist_attr.set_dims_mapping({0, 1, -1, -1});
  x_dist_attr.set_dynamic_dims({false, false, false, false});

  // Dist attr shared by all 1-D (channel-sized) tensors: fully replicated.
  auto one_dim_dist_attr = TensorDistAttr();
  one_dim_dist_attr.set_process_mesh(process_mesh);
  one_dim_dist_attr.set_dims_mapping({-1});
  one_dim_dist_attr.set_dynamic_dims({false});

  // Small helper to cut down on DistMetaTensor construction boilerplate.
  auto make_meta = [](const std::vector<int64_t>& shape,
                      const TensorDistAttr& attr) {
    return phi::distributed::DistMetaTensor(common::make_ddim(shape), attr);
  };

  // ---- forward (data_format = NCHW) ----
  // inputs  [0, 1, -1, -1],[-1],[-1],[-1],[-1]
  // outputs [-1, 1, -1, -1],[1],[1],[1],[1],[-1]
  phi::distributed::DistMetaTensor x = make_meta({16, 16, 16, 16}, x_dist_attr);
  phi::distributed::DistMetaTensor mean = make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor variance =
      make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor scale = make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor bias = make_meta({16}, one_dim_dist_attr);

  phi::distributed::SpmdInfo forward_info =
      phi::distributed::BatchNormInferSpmdStatic(x, mean, variance, scale,
                                                 bias);

  EXPECT_EQ(forward_info.first.size(), 5UL);
  EXPECT_EQ(forward_info.second.size(), 6UL);
  check_dim_mapping(forward_info.first[0], {-1, 1, -1, -1});
  check_dim_mapping(forward_info.first[1], {1});
  check_dim_mapping(forward_info.first[2], {1});
  check_dim_mapping(forward_info.first[3], {-1});
  check_dim_mapping(forward_info.first[4], {-1});
  check_dim_mapping(forward_info.second[0], {-1, 1, -1, -1});
  check_dim_mapping(forward_info.second[1], {1});
  check_dim_mapping(forward_info.second[2], {1});
  check_dim_mapping(forward_info.second[3], {1});
  check_dim_mapping(forward_info.second[4], {1});
  check_dim_mapping(forward_info.second[5], {-1});

  // ---- backward (data_format = NCHW) ----
  // inputs  [0, 1, -1, -1],[-1],[-1],[-1],[-1],[-1],[-1],[-1],[0, 1, -1, -1]
  // outputs [-1, 1, -1, -1],[-1],[-1]
  // dst inputs: [-1, 1, -1, -1],[-1],[-1],[1],[1],[1],[1],[-1],[-1, 1, -1, -1]
  x = make_meta({16, 16, 16, 16}, x_dist_attr);
  phi::distributed::DistMetaTensor out_grad =
      make_meta({16, 16, 16, 16}, x_dist_attr);
  phi::distributed::DistMetaTensor mean_out =
      make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor variance_out =
      make_meta({16}, one_dim_dist_attr);
  scale = make_meta({16}, one_dim_dist_attr);
  bias = make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor saved_mean =
      make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor saved_variance =
      make_meta({16}, one_dim_dist_attr);
  phi::distributed::DistMetaTensor reserve_space =
      make_meta({16}, one_dim_dist_attr);

  phi::distributed::SpmdInfo backward_info =
      phi::distributed::BatchNormGradInferSpmd(x,
                                               scale,
                                               bias,
                                               mean_out,
                                               variance_out,
                                               saved_mean,
                                               saved_variance,
                                               reserve_space,
                                               out_grad,
                                               0.9,
                                               0.1,
                                               "NCHW",
                                               false,
                                               false,
                                               false);

  EXPECT_EQ(backward_info.first.size(), 9UL);
  EXPECT_EQ(backward_info.second.size(), 3UL);
  check_dim_mapping(backward_info.first[0], {-1, 1, -1, -1});
  check_dim_mapping(backward_info.first[1], {-1});
  check_dim_mapping(backward_info.first[2], {-1});
  check_dim_mapping(backward_info.first[3], {1});
  check_dim_mapping(backward_info.first[4], {1});
  check_dim_mapping(backward_info.first[5], {1});
  check_dim_mapping(backward_info.first[6], {1});
  check_dim_mapping(backward_info.first[7], {-1});
  check_dim_mapping(backward_info.first[8], {-1, 1, -1, -1});

  check_dim_mapping(backward_info.second[0], {-1, 1, -1, -1});
  check_dim_mapping(backward_info.second[1], {-1});
  check_dim_mapping(backward_info.second[2], {-1});
}
28152922
TEST(Topk, Ctor) {
28162923
std::vector<int64_t> mesh_shape = {2, 2};
28172924
std::vector<int64_t> process_ids = {0, 1, 2, 3};

0 commit comments

Comments
 (0)