
[Auto Parallel] Add spmd rule No.4、13 for (batch_norm,sync_batch_norm) and their backward ops. #72918


Merged — 22 commits merged into PaddlePaddle:develop on Jun 16, 2025

Conversation

@Glencsa (Contributor) commented on May 24, 2025

PR Category

Auto Parallel

PR Types

New features

Description

  • [Open-source task] Develop operator sharding (SPMD) inference rules so that more models can use auto parallelism and users' distributed development cost is reduced.
  • No.4 batch_norm
    No.13 sync_batch_norm
  • Force every dimension except the batch dimension to Replicated (see the sketch after this list).
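As an illustration of that rule, here is a minimal hedged sketch (a hypothetical helper, not this PR's actual code) of the dims-mapping transformation, matching the forward-test notation later in this thread ([0, 1, -1] -> [0, -1, -1]):

#include <cstdint>
#include <vector>

// Hypothetical helper for illustration: keep the batch-dim shard of x and
// force every other dimension to Replicated, which dims_mapping encodes as
// -1. E.g. [0, 1, -1] (batch sharded on mesh dim 0, channel on mesh dim 1)
// becomes [0, -1, -1].
std::vector<int64_t> ForceNonBatchDimsReplicated(
    const std::vector<int64_t>& src_dims_mapping) {
  std::vector<int64_t> dst(src_dims_mapping.size(), -1);
  if (!dst.empty()) dst[0] = src_dims_mapping[0];
  return dst;
}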


@paddle-bot (bot) commented on May 24, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot added the "contributor (External developers)" label on May 24, 2025
@luotao1 changed the title from "[Auto Parallel] Add spmd rule for batch_norm and batch_norm_grad ops." to "[Auto Parallel] Add spmd rule No.4 for batch_norm and batch_norm_grad ops." on May 26, 2025
@luotao1 added the "HappyOpenSource Pro (advanced Happy Open Source program with more challenging tasks)" label on May 26, 2025
@Glencsa changed the title to "[Auto Parallel] Add spmd rule for No.4(batch_norm, batch_norm_grad) and No.13(sync_batch_norm,sync_batch_norm_grad) ops." on Jun 1, 2025
@Glencsa changed the title to "[Auto Parallel] Add spmd rule for No.4、13 for (batch_norm,sync_batch_norm) and their backward ops." on Jun 9, 2025
@Glencsa changed the title to "[Auto Parallel] Add spmd rule No.4、13 for (batch_norm,sync_batch_norm) and their backward ops." on Jun 9, 2025
@@ -2614,7 +2785,7 @@ TEST(Topk, Ctor) {

// test forward
// axis = 1
// [0, 1, -1] -> [0, -1, -1], [0, -1, -1]
// [0, -1, -1, 1],[-1],[-1],[-1],[-1] ->[-1 , -1, -1, 1],[1],[1],[1],[1],[1]
Contributor:

Should this annotation be modified? The change looks unintended.

Contributor (Author):

Thanks for pointing this out; I will revert this change in the next commit.

const std::string data_format,
const bool use_global_stats,
const bool trainable_statistics) {
return BatchNormInferSpmdBase(x, mean, variance, scale, bias);
Contributor:

Do we need the data_format parameter in BatchNormInferSpmdBase?
If the user passes data_format="NHWC" or "NLC", will the result still be correct?

Contributor (Author):

Thanks, I have handled all data_format cases in the new commit.
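For illustration, a hedged sketch of the kind of data_format handling meant here (the helper name and layout list are hypothetical, not the PR's actual code):

#include <string>

// Hypothetical helper: locate the channel axis for each supported layout so
// the rule can treat channels-first and channels-last inputs uniformly.
int ChannelAxisOf(const std::string& data_format, int ndim) {
  if (data_format == "NHWC" || data_format == "NLC" ||
      data_format == "NDHWC") {
    return ndim - 1;  // channels-last layouts keep the channel at the end
  }
  return 1;  // NCHW / NCL / NCDHW: channel right after the batch dimension
}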

const bool is_test,
const bool use_global_stats,
const bool trainable_statistics) {
return BatchNormGradInferSpmdBase(x,
Contributor:

Same data_format issue as in BatchNormInferSpmdBase.

Contributor (Author):

Thanks, I have handled all data_format cases in the new commit.

@@ -5056,6 +5056,7 @@
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
infer_meta :
func : BatchNormInferMeta
spmd_rule : SyncBatchNormInferSpmd
@jeff41404 (Contributor) commented on Jun 10, 2025:

The sync_batch_norm_ operator is used for manual parallelism, and its implementation includes communication rather than pure computation. Should it have an SPMD rule at all?

Contributor (Author):

Thanks, I think you are right. With sync_batch_norm_, each GPU holds a different batch, so the per-device mean and variance must be synchronized through communication and the tensors cannot simply be sharded. I will remove the spmd_rule for sync_batch_norm_ in the next commit.
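For context, the standard SyncBatchNorm statistics aggregation (textbook math, not quoted from this PR) shows why communication is unavoidable. With per-device batch sizes $N_k$, means $\mu_k$, and variances $\sigma_k^2$:

\mu = \frac{\sum_k N_k \,\mu_k}{\sum_k N_k},
\qquad
\sigma^2 = \frac{\sum_k N_k\left(\sigma_k^2 + \mu_k^2\right)}{\sum_k N_k} - \mu^2

Both quantities need a cross-device reduction before any device can normalize its local batch.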

Comment on lines 164 to 204
VLOG(4) << "Einsum Notation: " << x_axes << "," << mean_axes << ","
<< variance_axes << "," << scale_axes << "," << bias_axes << "-->"
<< out_axes << "," << mean_axes << "," << variance_axes;
VLOG(4) << "X"
<< " shape: [" << str_join(x_shape) << "] "
<< "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping())
<< "] "
<< "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]";
VLOG(4) << "Mean"
<< " shape: [" << str_join(mean_shape) << "] "
<< "src_dims_mapping: [" << str_join(mean_dims_mapping) << "] "
<< "dst_dims_mapping: ["
<< str_join(mean_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Variance"
<< " shape: [" << str_join(variance_shape) << "] "
<< "src_dims_mapping: [" << str_join(variance_dims_mapping) << "] "
<< "dst_dims_mapping: ["
<< str_join(variance_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Scale"
<< " shape: [" << str_join(scale_shape) << "] "
<< "src_dims_mapping: [" << str_join(scale_dims_mapping) << "] "
<< "dst_dims_mapping: ["
<< str_join(scale_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Bias"
<< " shape: [" << str_join(bias_shape) << "] "
<< "src_dims_mapping: [" << str_join(bias_dims_mapping) << "] "
<< "dst_dims_mapping: ["
<< str_join(bias_dist_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Out dims mapping: [" << str_join(out_dist_attr.dims_mapping())
<< "]";
VLOG(4) << "Mean_out dims mapping: ["
<< str_join(mean_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Variance_out dims mapping: ["
<< str_join(variance_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Saved_mean dims mapping: ["
<< str_join(mean_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Saved_variance dims mapping: ["
<< str_join(variance_dist_attr.dims_mapping()) << "]";
VLOG(4) << "Reserve_space dims mapping: ["
<< str_join(reserve_space_dist_attr.dims_mapping()) << "]";
VLOG(4) << std::endl;
Contributor:

Shall we use the LOG_SPMD_INPUT or LOG_SPMD_OUTPUT macro to simplify the logging code?

Contributor (Author):

Thanks, I will use it to simplify the logging code in the next commit.
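For reference, a hedged sketch of what such a token-pasting logging macro could look like and how it collapses the hand-written block above; the actual LOG_SPMD_INPUT definition in Paddle may differ:

// Illustrative macro in the spirit of LOG_SPMD_INPUT (not Paddle's exact
// definition). It assumes <name>_shape, <name>_dims_mapping_src and
// <name>_dims_mapping_dst variables are in scope for each input.
#define LOG_SPMD_INPUT(name)                                              \
  VLOG(4) << #name << " shape: [" << str_join(name##_shape) << "] "       \
          << "src_dims_mapping: [" << str_join(name##_dims_mapping_src)   \
          << "] dst_dims_mapping: [" << str_join(name##_dims_mapping_dst) \
          << "]"

// One line per tensor instead of five VLOG statements each:
LOG_SPMD_INPUT(x);
LOG_SPMD_INPUT(mean);
LOG_SPMD_INPUT(variance);
LOG_SPMD_INPUT(scale);
LOG_SPMD_INPUT(bias);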

Comment on lines 423 to 455
VLOG(4) << "Einsum Notation: " << x_axes << scale_axes << "," << bias_axes
<< "," << mean_out_axes << "," << variance_out_axes << ","
<< saved_mean_axes << "," << saved_variance_axes << ","
<< "-->" << reserve_space_axes << "," << out_grad_axes;
VLOG(4) << "X"
<< " shape: [" << str_join(x_shape) << "] "
<< "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping())
<< "] "
<< "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]";
VLOG(4) << "Mean_out"
<< " shape: [" << str_join(mean_out_shape) << "] "
<< "src_dims_mapping: ["
<< str_join(mean_out.dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: [" << str_join(mean_out_attr_dst.dims_mapping())
<< "]";
VLOG(4) << "Variance_out"
<< " shape: [" << str_join(variance_out_shape) << "] "
<< "src_dims_mapping: ["
<< str_join(variance_out.dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: ["
<< str_join(variance_out_attr_dst.dims_mapping()) << "]";
VLOG(4) << "Scale"
<< " shape: [" << str_join(scale_shape) << "] "
<< "src_dims_mapping: [" << str_join(scale.dist_attr().dims_mapping())
<< "] "
<< "dst_dims_mapping: [" << str_join(scale_attr_dst.dims_mapping())
<< "]";
VLOG(4) << "Bias"
<< " shape: [" << str_join(bias_shape) << "] "
<< "src_dims_mapping: [" << str_join(bias.dist_attr().dims_mapping())
<< "] "
<< "dst_dims_mapping: [" << str_join(bias_attr_dst.dims_mapping())
<< "]";
Contributor:

Shall we use the LOG_SPMD_INPUT or LOG_SPMD_OUTPUT macro to simplify the logging code here as well?

Contributor (Author):

Thanks, I will use it to simplify the logging code in the next commit.

@jeff41404 (Contributor) left a comment:

LGTM

@luotao1 merged commit 4aee08b into PaddlePaddle:develop on Jun 16, 2025
49 of 50 checks passed
@Glencsa deleted the batch_norm_spmd_rule branch on June 30, 2025
Labels
contributor (External developers), HappyOpenSource Pro (advanced Happy Open Source program with more challenging tasks)
3 participants