-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5th No.51】 为 Paddle 新增 flatten 的 spmd 切分推导规则 #57875
【Hackathon 5th No.51】 为 Paddle 新增 flatten 的 spmd 切分推导规则 #57875
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@pkuzyc 老师您好,ci过了,您有时间的话,麻烦审核一下哈~~ |
@@ -0,0 +1,284 @@ | |||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把单测加到 Paddle/test/auto_parallel/spmd_rules/CMakeLists.txt里,否则ci里单测运行不到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (axis >= ndim) { | ||
axis = ndim - 1; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
api里没有约定axis超过ndim时取最后一维,不要这么操作,可以直接报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} else { | ||
tgt_shape[tgt_shape.size() - 1] *= src_shape[i]; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tgt_shape后面似乎没有用?没有用的话去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
老师您好,后面会输出tgt_shape,有用到的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考后面的评论,输出 start_axis 和 stop_axis,tgt_shape 可以去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) | ||
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1]) | ||
<< "]\n\n"; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把 Out 的输出另用一个VLOG,这样看起来清楚,跑单测的时候输出调试信息看一下:
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "Out dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
|
||
class TestFlattenSPMDRule(unittest.TestCase): | ||
def setUp(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补一个[8,16,8,24]-->[8,16824]、不同切分状态的单测
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@pkuzyc 已按照review修改代码,ci通过了,老师有时间可以复审一下哈 |
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]); | ||
|
||
VLOG(4) << "FlattenInferSpmd: X shape: [" << str_join(src_shape) | ||
<< "] Out shape: [" << str_join(tgt_shape) << "]"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里输出 start_axis 和 stop_axis,不要输出 tgt_shape 了,把前面 tgt_shape 的计算也去掉。start_axis 和 stop_axis 是组网给定的,推导的时候用的也是这两个,没有用到 tgt_shape,VLOG 还是输出推导直接用到的内容,方便调试。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
} else { | ||
tgt_shape[tgt_shape.size() - 1] *= src_shape[i]; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考后面的评论,输出 start_axis 和 stop_axis,tgt_shape 可以去掉
@pkuzyc 修改完毕,ci通过了 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
…#57875) * Adding flatten spmd segmentation and derivation rules for Paddle * fix bugs * add unit test code * modified: test/auto_parallel/spmd_rules/CMakeLists.txt * modify the code according to the review * modified: paddle/phi/infermeta/spmd_rules/flatten.cc
…#57875) * Adding flatten spmd segmentation and derivation rules for Paddle * fix bugs * add unit test code * modified: test/auto_parallel/spmd_rules/CMakeLists.txt * modify the code according to the review * modified: paddle/phi/infermeta/spmd_rules/flatten.cc
…#57875) * Adding flatten spmd segmentation and derivation rules for Paddle * fix bugs * add unit test code * modified: test/auto_parallel/spmd_rules/CMakeLists.txt * modify the code according to the review * modified: paddle/phi/infermeta/spmd_rules/flatten.cc
PR types
Others
PR changes
Others
Description
为 Paddle 新增 flatten 的 spmd 切分推导规则
#57262