[Auto Parallel] Add spmd rule No.9 for group_norm and group_norm_grad ops. #72946


Merged
merged 27 commits on Jun 13, 2025
Commits (27)
8dfce3a: add unary ops which have spmd_rule but not add in yaml file. (Glencsa, Apr 10, 2025)
1d129c2: Merge branch 'spmd_test' into develop (Glencsa, Apr 15, 2025)
b9c9e6a: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, Apr 15, 2025)
f24c883: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, Apr 16, 2025)
746356c: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, Apr 16, 2025)
efc91c5: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, Apr 23, 2025)
2109cf9: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, May 8, 2025)
773fda6: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, May 22, 2025)
7681189: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, May 23, 2025)
f53affa: Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (Glencsa, May 26, 2025)
351038e: Add spmd_rule for group_norm ops. (Glencsa, May 26, 2025)
63d73a0: Add spmd_rule for group_norm ops. (Glencsa, May 26, 2025)
e9c66c9: add CI test for group_norm. (Glencsa, May 26, 2025)
02dc6ec: add CI test for group_norm. (Glencsa, May 26, 2025)
be1b913: fix bug. (Glencsa, May 27, 2025)
3532256: fix bug(PD_REGISTER_SPMD_RULE not surport string) (Glencsa, May 28, 2025)
50fdcd9: PD_REGISTER_SPMD_RULE need less than 5? (Glencsa, May 28, 2025)
048c148: fix bug. (Glencsa, May 29, 2025)
13865b8: fix bug. (Glencsa, May 30, 2025)
eb244af: fix bug. (Glencsa, Jun 1, 2025)
6ad8fdc: fix bug. (Glencsa, Jun 2, 2025)
84ce2d4: fix bug. (Glencsa, Jun 2, 2025)
981a7a1: fix bug. (Glencsa, Jun 2, 2025)
f82d50d: add partial status. (Glencsa, Jun 3, 2025)
441d091: apply review and resolve conflict (Glencsa, Jun 11, 2025)
b0ad519: fix bug. (Glencsa, Jun 11, 2025)
271bd1a: fix bug. (Glencsa, Jun 11, 2025)
338 changes: 338 additions & 0 deletions paddle/phi/infermeta/spmd_rules/group_norm.cc
@@ -0,0 +1,338 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/group_norm.h"

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi::distributed {

using phi::distributed::auto_parallel::str_join;
// Tensor x supports "NCL", "NCHW", "NCDHW", "NLC", "NHWC", "NDHWC".
// default: "NCHW"

SpmdInfo GroupNormInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias) {
// Step0: verify input args based on group_norm logic
auto x_shape = common::vectorize(x.dims());
auto scale_shape = common::vectorize(scale.dims());
auto bias_shape = common::vectorize(bias.dims());
int x_ndim = static_cast<int>(x_shape.size());
int scale_ndim = static_cast<int>(scale_shape.size());
int bias_ndim = static_cast<int>(bias_shape.size());
TensorDistAttr x_dist_attr_src = x.dist_attr();
TensorDistAttr scale_dist_attr_src = scale.dist_attr();
TensorDistAttr bias_dist_attr_src = bias.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> scale_dims_mapping = scale_dist_attr_src.dims_mapping();
std::vector<int64_t> bias_dims_mapping = bias_dist_attr_src.dims_mapping();

PADDLE_ENFORCE_GE(
x_ndim,
3,
common::errors::InvalidArgument(
"The ndim of x in group_norm should grater than 2, but got [%d].",
x_ndim));

PADDLE_ENFORCE_LE(
x_ndim,
5,
common::errors::InvalidArgument(
"The ndim of x in group_norm should be less than 6 , but got [%d].",
x_ndim));
PADDLE_ENFORCE_EQ(
scale_ndim,
1,
common::errors::InvalidArgument(
"The ndim of scale in group_norm should be 1, but got [%d].",
scale_ndim));

PADDLE_ENFORCE_EQ(
bias_ndim,
1,
common::errors::InvalidArgument(
"The ndim of bias in group_norm should be 1, but got [%d].",
bias_ndim));
// Step1: Build Einsum Notation
// Only N axis can be sharded.
std::string alphabet = "ijklmnopqrstuvwxyz";
std::string x_axes(x_ndim, '1');
for (int i = 0; i < x_ndim; ++i) {
x_axes[i] = alphabet[i];
}
std::string mean_axes(1, x_axes[0]);
std::string variance_axes(1, x_axes[0]);
std::string scale_axes(1, x_axes[0]);
std::string bias_axes(1, x_axes[0]);
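// Note: scale and bias are 1-D ([C]); the axis label above is only a
// placeholder, since both are forced to a replicated mapping ({-1}) below.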
// get output notation
std::string out_axes = x_axes;

// Step2: Sharding Propagation
// Step2.1: merge input sharding
for (int i = 1; i < x_ndim; ++i) {
x_dims_mapping[i] = -1;
}
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{x_axes, x_dims_mapping}});

// Step2.2: infer output dims mapping
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr mean_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr variance_dist_attr =
CopyTensorDistAttrForOutput(x_dist_attr_src);
out_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(out_axes, axis_to_dim_map));
mean_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
variance_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(variance_axes, axis_to_dim_map));

// Step2.3: update input dims mapping
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr scale_dist_attr_dst =
CopyTensorDistAttrForOutput(scale.dist_attr());
TensorDistAttr bias_dist_attr_dst =
CopyTensorDistAttrForOutput(bias.dist_attr());
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);

scale_dist_attr_dst.set_dims_mapping({-1});
bias_dist_attr_dst.set_dims_mapping({-1});

// Step2.4: handle input and output tensor partial
// GroupNorm does not support partial status here.
LOG_SPMD_INPUT(x);
LOG_SPMD_INPUT(scale);
LOG_SPMD_INPUT(bias);
LOG_SPMD_OUTPUT(out_dist_attr);
LOG_SPMD_OUTPUT(mean_dist_attr);
LOG_SPMD_OUTPUT(variance_dist_attr);

return {{x_dist_attr_dst, scale_dist_attr_dst, bias_dist_attr_dst},
{out_dist_attr, mean_dist_attr, variance_dist_attr}};
}
SpmdInfo GroupNormInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
float epsilon,
int groups,
const std::string& data_format) {
return GroupNormInferSpmdBase(x, scale, bias);
}

SpmdInfo GroupNormGradInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const DistMetaTensor& y,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& y_grad) {
// Step0: verify input args based on group_norm logic
auto x_shape = common::vectorize(x.dims());
auto scale_shape = common::vectorize(scale.dims());
auto bias_shape = common::vectorize(bias.dims());
auto y_shape = common::vectorize(y.dims());
auto mean_shape = common::vectorize(mean.dims());
auto variance_shape = common::vectorize(variance.dims());
auto y_grad_shape = common::vectorize(y_grad.dims());
int x_ndim = static_cast<int>(x_shape.size());
int scale_ndim = static_cast<int>(scale_shape.size());
int bias_ndim = static_cast<int>(bias_shape.size());
int y_ndim = static_cast<int>(y_shape.size());
int mean_ndim = static_cast<int>(mean_shape.size());
int variance_ndim = static_cast<int>(variance_shape.size());
int y_grad_ndim = static_cast<int>(y_grad_shape.size());
TensorDistAttr x_dist_attr_src = x.dist_attr();
TensorDistAttr scale_dist_attr_src = scale.dist_attr();
TensorDistAttr bias_dist_attr_src = bias.dist_attr();
TensorDistAttr y_dist_attr_src = y.dist_attr();
TensorDistAttr mean_dist_attr_src = mean.dist_attr();
TensorDistAttr variance_dist_attr_src = variance.dist_attr();
TensorDistAttr y_grad_dist_attr_src = y_grad.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
std::vector<int64_t> scale_dims_mapping = scale.dist_attr().dims_mapping();
std::vector<int64_t> bias_dims_mapping = bias.dist_attr().dims_mapping();
std::vector<int64_t> y_dims_mapping = y.dist_attr().dims_mapping();
std::vector<int64_t> mean_dims_mapping = mean.dist_attr().dims_mapping();
std::vector<int64_t> variance_dims_mapping = variance.dist_attr().dims_mapping();
std::vector<int64_t> y_grad_dims_mapping = y_grad.dist_attr().dims_mapping();

PADDLE_ENFORCE_GE(
x_ndim,
3,
common::errors::InvalidArgument(
"The ndim of x in group_norm should grater than 2, but got [%d].",
x_ndim));

PADDLE_ENFORCE_LE(
x_ndim,
5,
common::errors::InvalidArgument(
"The ndim of x in group_norm should be less than 6 , but got [%d].",
x_ndim));
PADDLE_ENFORCE_EQ(x_ndim,
y_ndim,
common::errors::InvalidArgument(
"The ndim of x and y in group_norm should be equal, "
"but got x:[%d] and y[%d] .",
x_ndim,
y_ndim));
PADDLE_ENFORCE_EQ(
x_ndim,
y_grad_ndim,
common::errors::InvalidArgument(
"The ndim of x and y_grad in group_norm should be equal, "
"but got x:[%d] and y_grad[%d] .",
x_ndim,
y_grad_ndim));
PADDLE_ENFORCE_EQ(
scale_ndim,
1,
common::errors::InvalidArgument(
"The ndim of scale in group_norm should be 1, but got [%d].",
scale_ndim));

PADDLE_ENFORCE_EQ(
bias_ndim,
1,
common::errors::InvalidArgument(
"The ndim of bias in group_norm should be 1, but got [%d].",
bias_ndim));
PADDLE_ENFORCE_EQ(
mean_ndim,
1,
common::errors::InvalidArgument(
"The ndim of mean in group_norm should be 1, but got [%d].",
mean_ndim));
PADDLE_ENFORCE_EQ(
variance_ndim,
1,
common::errors::InvalidArgument(
"The ndim of variance in group_norm should be 1, but got [%d].",
variance_ndim));

// Step1: Build Einsum Notation
// Only N axis can be sharded.
std::string alphabet = "ijklmnopqrstuvwxyz";
// input
std::string x_axes(x_ndim, '1');
std::string y_axes(y_ndim, '1');
std::string y_grad_axes(y_grad_ndim, '1');

for (int i = 0; i < x_ndim; ++i) {
x_axes[i] = alphabet[i];
y_axes[i] = alphabet[i];
y_grad_axes[i] = alphabet[i];
}
std::string scale_axes(1, x_axes[0]);
std::string bias_axes(1, x_axes[0]);
std::string mean_axes(1, x_axes[0]);
std::string variance_axes(1, x_axes[0]);
// output
std::string x_grad_axes = x_axes;
std::string scale_grad_axes(1, x_axes[0]);  // placeholder; forced to replicated below
std::string bias_grad_axes(1, x_axes[0]);
// Step2: Sharding Propagation
// Step2.1: merge input sharding
for (int i = 1; i < x_ndim; ++i) {
x_dims_mapping[i] = -1;
}
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors({{x_axes, x_dims_mapping}});

// Step2.2: infer output dims mapping
TensorDistAttr x_grad_dist_attr =
CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr scale_grad_dist_attr =
CopyTensorDistAttrForOutput(scale.dist_attr());
TensorDistAttr bias_grad_dist_attr =
CopyTensorDistAttrForOutput(bias.dist_attr());
x_grad_dist_attr.set_dims_mapping(
GetDimsMappingForAxes(x_grad_axes, axis_to_dim_map));
scale_grad_dist_attr.set_dims_mapping({-1});
bias_grad_dist_attr.set_dims_mapping({-1});

// Step2.3: update input dims mapping
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
TensorDistAttr scale_dist_attr_dst =
CopyTensorDistAttrForOutput(scale.dist_attr());
TensorDistAttr bias_dist_attr_dst =
CopyTensorDistAttrForOutput(bias.dist_attr());
TensorDistAttr y_dist_attr_dst = CopyTensorDistAttrForOutput(y.dist_attr());
TensorDistAttr mean_dist_attr_dst =
CopyTensorDistAttrForOutput(mean.dist_attr());
TensorDistAttr variance_dist_attr_dst =
CopyTensorDistAttrForOutput(variance.dist_attr());
TensorDistAttr y_grad_dist_attr_dst =
CopyTensorDistAttrForOutput(y_grad.dist_attr());
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
y_dist_attr_dst.set_dims_mapping(x_dims_mapping);
y_grad_dist_attr_dst.set_dims_mapping(x_dims_mapping);
mean_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
variance_dist_attr_dst.set_dims_mapping(
GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
scale_dist_attr_dst.set_dims_mapping({-1});
bias_dist_attr_dst.set_dims_mapping({-1});

std::vector<int64_t> partial_on_dims;
const auto& dim_mapping = x_dims_mapping;
for (int i = 0; i < x_ndim; ++i) {
auto mapping = dim_mapping[i];
if (mapping != -1) {
partial_on_dims.push_back(mapping);
}
}
scale_grad_dist_attr.set_partial_status(partial_on_dims);
bias_grad_dist_attr.set_partial_status(partial_on_dims);
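// scale_grad and bias_grad are reductions over the batch, so when N is sharded
// each rank holds only a partial sum; the partial status set above lets the
// framework insert the required reduction.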

LOG_SPMD_INPUT(x);
LOG_SPMD_INPUT(scale);
LOG_SPMD_INPUT(bias);
LOG_SPMD_INPUT(y);
LOG_SPMD_INPUT(mean);
LOG_SPMD_INPUT(variance);
LOG_SPMD_INPUT(y_grad);
LOG_SPMD_OUTPUT(x_grad_dist_attr);
LOG_SPMD_OUTPUT(scale_grad_dist_attr);
LOG_SPMD_OUTPUT(bias_grad_dist_attr);

return {{x_dist_attr_dst,
scale_dist_attr_dst,
bias_dist_attr_dst,
y_dist_attr_dst,
mean_dist_attr_dst,
variance_dist_attr_dst,
y_grad_dist_attr_dst},
{x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr}};
}
SpmdInfo GroupNormGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const DistMetaTensor& y,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& y_grad,
float epsilon,
int groups,
const std::string& data_format) {
return GroupNormGradInferSpmdBase(x, scale, bias, y, mean, variance, y_grad);
}
} // namespace phi::distributed
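
For reference (not part of the diff): a minimal sketch of how the forward rule can be exercised, modeled on the style of Paddle's existing C++ SPMD-rule unit tests. The mesh, shapes, helper name, and header paths below are illustrative assumptions, not code from this PR.

// Sketch only: drive GroupNormInferSpmd with x sharded on the batch axis.
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/infermeta/spmd_rules/group_norm.h"

using phi::distributed::DistMetaTensor;
using phi::distributed::ProcessMesh;
using phi::distributed::TensorDistAttr;

void GroupNormSpmdSketch() {  // hypothetical helper, illustration only
  ProcessMesh mesh({2}, {0, 1}, {"x"});  // 1-D mesh of two ranks

  TensorDistAttr x_attr;
  x_attr.set_process_mesh(mesh);
  x_attr.set_dims_mapping({0, -1, -1, -1});  // x: [N, C, H, W], N sharded on mesh dim 0

  TensorDistAttr scale_attr;
  scale_attr.set_process_mesh(mesh);
  scale_attr.set_dims_mapping({-1});  // scale/bias: [C], replicated
  TensorDistAttr bias_attr = scale_attr;

  DistMetaTensor x(common::make_ddim({16, 32, 8, 8}), x_attr);
  DistMetaTensor scale(common::make_ddim({32}), scale_attr);
  DistMetaTensor bias(common::make_ddim({32}), bias_attr);

  // Per the rule above: x keeps {0, -1, -1, -1}; scale/bias stay {-1};
  // out follows x; mean/variance follow the batch axis.
  auto spmd_info = phi::distributed::GroupNormInferSpmd(
      x, scale, bias, /*epsilon=*/1e-5f, /*groups=*/8, /*data_format=*/"NCHW");
  (void)spmd_info;
}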
49 changes: 49 additions & 0 deletions paddle/phi/infermeta/spmd_rules/group_norm.h
@@ -0,0 +1,49 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {
SpmdInfo GroupNormInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias);
SpmdInfo GroupNormInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
float epsilon = 1e-5,
int groups = -1,
const std::string& data_format = "NCHW");
SpmdInfo GroupNormGradInferSpmdBase(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const DistMetaTensor& y,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& y_grad);
SpmdInfo GroupNormGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& scale,
const DistMetaTensor& bias,
const DistMetaTensor& y,
const DistMetaTensor& mean,
const DistMetaTensor& variance,
const DistMetaTensor& y_grad,
float epsilon = 1e-5,
int groups = -1,
const std::string& data_format = "NCHW");
} // namespace distributed
} // namespace phi
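
Similarly, a sketch (again with assumed setup, following the forward sketch above) of what the grad rule propagates when x is sharded on the batch axis, in particular the partial status placed on the scale/bias gradients:

// Sketch only: inputs are assumed to carry the same dist_attrs as above.
void GroupNormGradSpmdSketch(const phi::distributed::DistMetaTensor& x,
                             const phi::distributed::DistMetaTensor& scale,
                             const phi::distributed::DistMetaTensor& bias,
                             const phi::distributed::DistMetaTensor& y,
                             const phi::distributed::DistMetaTensor& mean,
                             const phi::distributed::DistMetaTensor& variance,
                             const phi::distributed::DistMetaTensor& y_grad) {
  auto spmd_info = phi::distributed::GroupNormGradInferSpmd(
      x, scale, bias, y, mean, variance, y_grad,
      /*epsilon=*/1e-5f, /*groups=*/8, /*data_format=*/"NCHW");

  // With x mapped as {0, -1, -1, -1}:
  //   - x_grad follows x's mapping,
  //   - scale_grad / bias_grad come back replicated ({-1}) but marked partial
  //     on mesh dim 0, since each rank reduces only its local batch slice.
  // The output dist_attrs are the second element of the returned pair.
  (void)spmd_info;
}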
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/spmd_rules/rules.cc
@@ -748,6 +748,9 @@ PD_REGISTER_SPMD_RULE(pad,
PD_INFER_SPMD(phi::distributed::PadInferSpmd),
PD_INFER_SPMD(phi::distributed::PadGradInferSpmd));

// group_norm
PD_REGISTER_SPMD_RULE(group_norm,
PD_INFER_SPMD(phi::distributed::GroupNormInferSpmdBase));
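// Note: only the attribute-free Base variant is registered here; judging from
// the commit messages above ("PD_REGISTER_SPMD_RULE not surport string"), the
// macro's argument capture does not handle the std::string data_format
// attribute, so the attribute-taking GroupNormInferSpmd/GroupNormGradInferSpmd
// are presumably reached through the op's yaml spmd_rule entry instead.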
// nonzero
PD_REGISTER_SPMD_RULE(nonzero,
PD_INFER_SPMD(phi::distributed::NonZeroInferSpmd),