-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.50】 为 Paddle 新增 slice 的 spmd 切分推导规则 #57866
【Hackathon 5th No.50】 为 Paddle 新增 slice 的 spmd 切分推导规则 #57866
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@pkuzyc 老师您好,CI已过,可以review了 |
const std::vector<int>& starts, | ||
const std::vector<int>& ends, | ||
const std::vector<int64_t>& infer_flags, | ||
const std::vector<int64_t>& decrease_axis) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
代码里面加上注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pkuzyc 本地代码注释大体已加好,但是,涉及到下面两个点的地方,可能需要修改。麻烦老师先看一下~~
|
||
for (int i = 0; i < static_cast<int>(axes.size()); i++) { | ||
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; | ||
input_axes[axis] = special_axes[i]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为什么需要 special_axes?out_axes 里面切分的维度是 '1',对应的维度不会传到 output 上。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
老师您好。这里借鉴的是split的切分推导规则。 split
只涉及一个
axis
,它用保留的 k
去做特殊标记。类比过来, slice
涉及多个 axis
,所以需要多个保留的(没用到的)字母去做特殊标记。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我觉得 slice 可以不用特殊标记,在最后的 log 里把 axes 打一下吧。其他规则里面也会用 'k' 做特殊标记,所以 split 里也用了 'k' 标记。slice 因为有多个 axis,不像 split 只有一个特殊维度,这么标的话也看不出哪些是特殊维度。
for (int i = 0; i < static_cast<int>(axes.size()); i++) { | ||
int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; | ||
out_axes[axis] = special_axes[i]; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
逆向推导的notation和正向保持一致,不一样的话会有点迷惑。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里也是参照的 split
。我看了下 split
的日志,有一个 notation
是这样的。
- 正向:abck-->abc1
- 逆向:abck-->abck
其中, k
是特殊字母。这种算不算保持一致呀?如果不算,按这个例子来说,应该都改成什么样子呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
slice 这里都用 abcd --> abc1 这种吧,主要还是因为有多个特殊维度,感觉这样可以直接看出来哪些是切分的。split 我之后看看需不需要改下,这样确实不大一致。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
const std::vector<int64_t>& axes, | ||
const std::vector<int>& starts, | ||
const std::vector<int>& ends, | ||
const std::vector<int64_t>& infer_flags, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果一些attr 对切分推导无影响,只是为了和 phi ymal 中定义对齐, 用注释说明一下,下个pr 里可以更新一下
…7866) * Add spmd segmentation and derivation rules for slice for Paddle * fix bugs * fix bugs * add unit test code * modified: test/auto_parallel/spmd_rules/CMakeLists.txt * test * fix bugs * fix bugs
PR types
Others
PR changes
Others
Description
为 Paddle 新增 slice 的 spmd 切分推导规则
#57262