Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.50】 为 Paddle 新增 slice 的 spmd 切分推导规则 #57866

Merged
merged 11 commits into from
Oct 27, 2023

Conversation

WintersMontagne10335
Copy link
Contributor

@WintersMontagne10335 WintersMontagne10335 commented Oct 3, 2023

PR types

Others

PR changes

Others

Description

为 Paddle 新增 slice 的 spmd 切分推导规则
#57262

@paddle-bot
Copy link

paddle-bot bot commented Oct 3, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Oct 3, 2023
@WintersMontagne10335
Copy link
Contributor Author

@pkuzyc 老师您好,CI已过,可以review了

const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码里面加上注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pkuzyc 本地代码注释大体已加好,但是,涉及到下面两个点的地方,可能需要修改。麻烦老师先看一下~~


for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
input_axes[axis] = special_axes[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么需要 special_axes?out_axes 里面切分的维度是 '1',对应的维度不会传到 output 上。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

老师您好。这里借鉴的是split的切分推导规则split 只涉及一个
axis ,它用保留的 k 去做特殊标记。类比过来, slice 涉及多个 axis ,所以需要多个保留的(没用到的)字母去做特殊标记。

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我觉得 slice 可以不用特殊标记,在最后的 log 里把 axes 打一下吧。其他规则里面也会用 'k' 做特殊标记,所以 split 里也用了 'k' 标记。slice 因为有多个 axis,不像 split 只有一个特殊维度,这么标的话也看不出哪些是特殊维度。

for (int i = 0; i < static_cast<int>(axes.size()); i++) {
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i];
out_axes[axis] = special_axes[i];
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

逆向推导的notation和正向保持一致,不一样的话会有点迷惑。

Copy link
Contributor Author

@WintersMontagne10335 WintersMontagne10335 Oct 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也是参照的 split 。我看了下 split 的日志,有一个 notation 是这样的。

  • 正向:abck-->abc1
  • 逆向:abck-->abck

其中, k 是特殊字母。这种算不算保持一致呀?如果不算,按这个例子来说,应该都改成什么样子呢?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slice 这里都用 abcd --> abc1 这种吧,主要还是因为有多个特殊维度,感觉这样可以直接看出来哪些是切分的。split 我之后看看需不需要改下,这样确实不大一致。

Copy link
Contributor

@JZ-LIANG JZ-LIANG left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

const std::vector<int64_t>& axes,
const std::vector<int>& starts,
const std::vector<int>& ends,
const std::vector<int64_t>& infer_flags,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果一些attr 对切分推导无影响,只是为了和 phi ymal 中定义对齐, 用注释说明一下,下个pr 里可以更新一下

@JZ-LIANG JZ-LIANG merged commit b83ce89 into PaddlePaddle:develop Oct 27, 2023
28 checks passed
@WintersMontagne10335
Copy link
Contributor Author

@JZ-LIANG @pkuzyc 老师要不这个PR先revert一下吧。我今天改52题的时候,突然想到,infer_flagsdecrease_axis 虽然不影响切分结果(单测绕过 空list 了),但是实际场景是可能为 空list 的,所以可能会出错。
黑客松12月15日结束,在那之前,那个bug能修好就行。

danleifeng pushed a commit to danleifeng/Paddle that referenced this pull request Nov 14, 2023
…7866)

* Add spmd segmentation and derivation rules for slice for Paddle

* fix bugs

* fix bugs

* add unit test code

* modified:   test/auto_parallel/spmd_rules/CMakeLists.txt

* test

* fix bugs

* fix bugs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants