
[Auto Parallel] Add spmd rule No.6 for unique ops. #72824


Merged
merged 6 commits into PaddlePaddle:develop on May 29, 2025

Conversation

ooooo-create
Contributor

@ooooo-create ooooo-create commented May 20, 2025

PR Category

Auto Parallel

PR Types

New features

Description


paddle-bot bot commented May 20, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 20, 2025
@ooooo-create ooooo-create changed the title [Auto Parallel] Add spmd rule No.6 for unique and unique_grad ops. [Auto Parallel] Add spmd rule No.6 for unique ops. May 20, 2025
@luotao1 luotao1 added the HappyOpenSource Pro (advanced happy open source: more challenging tasks) label May 21, 2025
TensorDistAttr indices_dist_attr_dst = TensorDistAttr();
if (return_index) {
  indices_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
  indices_dist_attr_dst.set_dims_mapping({-1});
Contributor

@Yeenyeong Yeenyeong May 23, 2025

If the variable "axis" is not "None", the number of dimensions of "indices" should be the same as that of the input "x" (and likewise for "inverse" and "counts").
It does not make sense to simply set "dims_mapping" to {-1}.

Contributor Author

If the variable "axis" is not "None", then according to the UniqueRawInferMeta they are 1-D tensors along "axis" (screenshot omitted).

Contributor

Oh, you're right! Sorry for misunderstanding the operator. Thanks for helping me understand it correctly!
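To see why the concern resolves this way, here is a toy pure-Python stand-in for unique along axis 0 (a hypothetical helper, not Paddle code): whatever the row width, "index", "inverse", and "counts" come back as flat 1-D sequences along the chosen axis, which is the shape convention UniqueRawInferMeta describes.

```python
def unique_rows(x):
    """Toy stand-in for unique(x, axis=0): returns unique rows plus
    1-D index / inverse / counts, mirroring the 1-D output shapes
    that UniqueRawInferMeta specifies when `axis` is given."""
    uniques, index, inverse, counts = [], [], [], []
    seen = {}  # maps a row (as a tuple) to its slot in `uniques`
    for i, row in enumerate(x):
        key = tuple(row)
        if key not in seen:
            seen[key] = len(uniques)
            uniques.append(row)
            index.append(i)   # first occurrence along axis 0
            counts.append(0)
        inverse.append(seen[key])
        counts[seen[key]] += 1
    return uniques, index, inverse, counts

u, idx, inv, cnt = unique_rows([[1, 2], [1, 2], [3, 4]])
# idx == [0, 2], inv == [0, 0, 1], cnt == [2, 1] -- all flat 1-D lists,
# regardless of how many trailing dimensions the input rows have.
```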

bool return_index,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis);
Contributor

Maybe we do not need this interface.

Contributor Author

Thanks, I will delete it later.

bool return_counts,
const std::vector<int>& axis,
DataType dtype,
bool is_sorted);
Contributor

Why do we need a distinct "Static" interface?

Contributor Author

Because the unique op has different parameters in dynamic mode and static mode. As I understand it, auto parallel should support both modes, but I'm not sure how the framework handles this, so I added this interface with a 'Static' suffix and registered it in 'paddle/phi/ops/yaml/inconsistent/static_ops.yaml'. It's only my guess. Could you tell me more about it? Do I need to delete it? Thanks!

Comment on lines 45 to 68
# return_index=True, return_inverse=True, return_counts=True, axis={}
# [0, -1] --> [-1,-1], [-1], [-1], [-1], [-1]
# self.x_dist_tensor_spec.set_dims_mapping([0, -1])
# result_dist_attrs = self.rule.infer_forward(
# self.x_dist_tensor_spec,
# self.attrs["return_index"],
# self.attrs["return_inverse"],
# self.attrs["return_counts"],
# self.attrs["axis"],
# self.attrs['dtype'],
# )

# self.assertEqual(len(result_dist_attrs), 2)
# inferred_input_dist_attrs = result_dist_attrs[0]
# inferred_output_dist_attrs = result_dist_attrs[1]

# self.assertEqual(len(inferred_input_dist_attrs), 1)
# self.assertEqual(len(inferred_output_dist_attrs), 4)

# self.assertEqual(inferred_input_dist_attrs[0].dims_mapping, [-1, -1])
# self.assertEqual(inferred_output_dist_attrs[0].dims_mapping, [-1])
# self.assertEqual(inferred_output_dist_attrs[1].dims_mapping, [-1])
# self.assertEqual(inferred_output_dist_attrs[2].dims_mapping, [-1])
# self.assertEqual(inferred_output_dist_attrs[3].dims_mapping, [-1])
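For reference, the behavior this disabled test encodes can be sketched in plain Python (a hypothetical model of the rule, not the actual Paddle implementation): unique must compare elements globally, so the rule replicates any sharded input axis and maps every 1-D output to [-1].

```python
def unique_infer_forward(x_dims_mapping, return_index, return_inverse, return_counts):
    """Sketch of the unique forward SPMD rule under test: since `unique`
    needs a global view of the data, the input is replicated (every
    dims_mapping entry becomes -1) and each requested 1-D output
    (out, index, inverse, counts) gets dims_mapping [-1]."""
    x_dst = [-1] * len(x_dims_mapping)  # replicate the sharded input
    outputs = [[-1]]                    # `out` itself is 1-D in this case
    for flag in (return_index, return_inverse, return_counts):
        if flag:
            outputs.append([-1])
    return [x_dst], outputs

inputs, outputs = unique_infer_forward([0, -1], True, True, True)
# inputs == [[-1, -1]]; outputs == [[-1], [-1], [-1], [-1]],
# matching the expectations in the commented-out test above.
```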
Contributor

@jeff41404 jeff41404 May 26, 2025

Shall we enable this commented-out code?

Contributor Author

Thanks, I forgot about it. Enabling it directly would fail, because it passes "axis" as an empty list, and it is difficult to decide which C++ type an empty list should map to (std::vector&lt;int&gt; or std::vector&lt;int64_t&gt;). So it stays disabled for now, since the same cases are already tested in C++. Maybe adding a default parameter would let us enable this test; I will try it later.

Contributor Author

@ooooo-create ooooo-create May 27, 2025

I found that this problem had already come up in PR #57877 (comment). Now I have added support for empty lists in parse_single_pyobject, forcing them to vector&lt;int64&gt;, and added std::vector&lt;int&gt; InferSpmdContext::AttrAt(size_t idx) const to convert the values back to plain int. (screenshots omitted)
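The idea behind that fix can be sketched in plain Python (all names here are illustrative, not Paddle's actual API): an empty Python list attribute carries no element type, so the parser canonicalizes every int-list attribute to one storage type (int64), and a narrowing accessor, playing the role of the AttrAt overload, converts to plain int on read.

```python
def parse_int_list_attr(pyobj):
    """Canonicalize an int-list attribute: always store as int64,
    so an empty list [] no longer has an ambiguous element type."""
    return {"type": "int64[]", "value": [int(v) for v in pyobj]}

def attr_as_int32_list(attr):
    """Narrowing accessor (sketch of the AttrAt<std::vector<int>> idea):
    read the stored int64 list back as plain 32-bit ints, checking range."""
    assert attr["type"] == "int64[]"
    out = []
    for v in attr["value"]:
        assert -2**31 <= v < 2**31, "value out of int32 range"
        out.append(v)
    return out

axis = parse_int_list_attr([])  # an empty `axis` now has a definite type
# attr_as_int32_list(axis) == []
```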

@ooooo-create ooooo-create requested a review from jeff41404 May 27, 2025 09:18
Contributor

@jeff41404 jeff41404 left a comment

LGTM

@luotao1 luotao1 merged commit 44de935 into PaddlePaddle:develop May 29, 2025
117 of 122 checks passed
Labels
contributor (External developers), HappyOpenSource Pro (advanced happy open source: more challenging tasks), skip-ci: xpu
5 participants