
Fit PIR AMP for auto_parallel #65892


Merged: 19 commits merged into PaddlePaddle:develop on Jul 22, 2024

Conversation

@zhiqiu (Contributor) commented Jul 9, 2024

PR Category

Auto Parallel

PR Types

New features

Description

Fit PIR AMP for auto_parallel

Pcard-76459

paddle-bot bot commented Jul 9, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@zhiqiu changed the title from "Amp" to "Fit PIR AMP for auto_parallel" on Jul 19, 2024
@@ -856,7 +859,8 @@ void SetInplaceOutputCorrectDistAttr(
     phi::distributed::DistTensor* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
     if (dist_tensor->initialized()) {
-      if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr[i])) {
+      if (ReshardIsNeededWithPartial(dist_tensor->dist_attr(),
Contributor:

Will this change cause any side effects? SetInplaceOutputCorrectDistAttr seems to be called by many operator implementations.

Contributor Author (zhiqiu):

Not considering partial before was likely a bug. Here, set_dist_attr is needed in every case; only some cases additionally need a reshard.
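
Illustrative only: a minimal Python-style sketch of the control flow described above (the actual implementation is the C++ SetInplaceOutputCorrectDistAttr); the helpers reshard_is_needed_with_partial, reshard, and set_dist_attr as used here are hypothetical stand-ins.

def set_inplace_output_dist_attr(dist_tensor, target_dist_attr):
    # Reshard only when the current and target dist_attrs differ,
    # including differences in partial status.
    if dist_tensor.initialized() and reshard_is_needed_with_partial(
            dist_tensor.dist_attr, target_dist_attr):
        reshard(dist_tensor, target_dist_attr)  # hypothetical helper
    # set_dist_attr happens in every case, reshard only in some cases.
    dist_tensor.set_dist_attr(target_dist_attr)  # hypothetical setter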

if (auto vec_type = value.type().dyn_cast<pir::VectorType>()) {
  for (size_t idx = 0; idx < vec_type.size(); ++idx) {
    if (auto dist_type = vec_type[idx].dyn_cast<DistTypeInterface>()) {
      meshes.insert(dist_type.process_mesh_attr());
Contributor:

If an element is not a DistTypeInterface, should this raise an error?

Contributor Author (zhiqiu):

Elements that are not DistTypeInterface simply do not contribute a sub-mesh. We currently do not raise an error for them, because this function is only called from check_finite_and_unscale and update_loss_scaling.
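
Illustrative only: a Python sketch of the mesh-collection behavior being discussed; the elements() accessor and process_mesh attribute are hypothetical stand-ins for the pir::VectorType / DistTypeInterface calls above.

def collect_sub_meshes(vector_value):
    meshes = set()
    for element in vector_value.elements():  # hypothetical accessor
        dist_type = getattr(element, "dist_type", None)
        # Non-dist elements simply contribute no sub-mesh; no error is raised.
        if dist_type is not None:
            meshes.add(dist_type.process_mesh)  # hypothetical attribute
    return meshes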

        << (mesh.process_ids() != first_mesh.process_ids());
    if (mesh.shape() == first_mesh.shape() &&
        mesh.dim_names() == first_mesh.dim_names() &&
        mesh.process_ids() != first_mesh.process_ids()) {
Contributor:

It feels like we also need to handle the case where first_mesh and mesh are not fully equal (only partially equal), e.g.:
first_mesh = [[1,2],[3,4]]; mesh = [[3,4],[5,6]]
Here mesh.process_ids() != first_mesh.process_ids(), but this case should also raise an error.

Contributor Author (zhiqiu):

This if condition already covers that check; such a case will raise an error.
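
Illustrative only: a small Python check mirroring the condition above, using the reviewer's example meshes (plain dicts stand in for ProcessMesh) to show that the partially overlapping case is indeed rejected.

# Reviewer's example: same shape and dim names, partially overlapping process ids.
first_mesh = {"shape": [2, 2], "dim_names": ["x", "y"], "process_ids": [1, 2, 3, 4]}
mesh = {"shape": [2, 2], "dim_names": ["x", "y"], "process_ids": [3, 4, 5, 6]}

# The same condition as in the snippet above, written as plain Python.
if (mesh["shape"] == first_mesh["shape"]
        and mesh["dim_names"] == first_mesh["dim_names"]
        and mesh["process_ids"] != first_mesh["process_ids"]):
    raise ValueError("meshes with equal shape/dim_names must share process_ids")
# -> raises, so the partially-equal case hits the error branch as the author says.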

# re-run apply_mix2dist_pass to dist accumulator.
apply_mix2dist_pass(dist_program)
if self._strategy.amp.enable:
    amp_lists = paddle.static.amp.decorator.AutoMixedPrecisionLists(
Contributor:

Recommend wrapping this in an AMP helper function? Otherwise the main flow of the code becomes unclear, lol.

Contributor Author (zhiqiu):

Will optimize this in a follow-up; the current interface still has some issues.
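
Illustrative only: a sketch of the refactor the reviewer suggests, pulling the AMP setup out of the main flow into a helper. The helper name _apply_amp_pass and its body are hypothetical; only AutoMixedPrecisionLists and self._strategy.amp.enable appear in the diff above.

def _apply_amp_pass(self, dist_program):
    # Hypothetical helper: keeps AMP-specific setup out of the main lowering flow.
    if not self._strategy.amp.enable:
        return
    amp_lists = paddle.static.amp.decorator.AutoMixedPrecisionLists(
        # custom white/black lists would be derived from self._strategy.amp here
    )
    ...  # build and run the AMP transformation with amp_lists

# The inline block in the main flow then collapses to:
#     self._apply_amp_pass(dist_program)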

@@ -699,8 +743,17 @@ def _parallel_pir(self, mode):
    # collect the communicator created during resolution.
    apply_reshard_pass(dist_program)

    # print('after reshard', dist_program, flush=1)
Contributor:

remove

Contributor Author (zhiqiu):

I left it in on purpose to make debugging easier; it can be upgraded to a log call later.
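
Illustrative only: one way the debug print could be upgraded to a log call later, using Python's standard logging module; the logger name is hypothetical.

import logging

logger = logging.getLogger("auto_parallel.engine")  # hypothetical logger name

# A debug-level log keeps the hook available without polluting normal output:
logger.debug("after reshard: %s", dist_program)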

reduce_ops = {
    paddle.base.core.ReduceType.kRedSum: paddle._C_ops.c_allreduce_sum,
    paddle.base.core.ReduceType.kRedAvg: paddle._C_ops.c_allreduce_avg,
    paddle.base.core.ReduceType.kRedMax: paddle._C_ops.c_allreduce_max,
Contributor:

Is reduce max really necessary? A reduce sum could also be used to detect whether a nan/inf appeared.

Contributor Author (zhiqiu):

In the inference rules, taking a sum over a bool feels a bit odd?
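
Illustrative only: a tiny Python example of the semantics being debated. Reducing boolean found-inf flags with max is just a logical OR, while summing them only works by reinterpreting the bools as integers.

# Per-rank "found inf/nan" flags, as booleans.
found_inf_flags = [False, True, False, False]

# max over bools is a logical OR: True iff any rank saw inf/nan.
any_inf_via_max = max(found_inf_flags)        # True

# sum also detects it, but only by treating bools as 0/1 integers,
# which is the dtype quirk the author finds odd for the inference rule.
any_inf_via_sum = sum(found_inf_flags) > 0    # True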

@@ -1081,7 +1144,9 @@ def _initialize(self, mode, init_parameters=True):
    # 4. lazy init adaption
    # 5. amp init adaption
    # 6. vpp init adaption

    self.program_helper.init_pir(
        self._pir_dist_main_progs[mode], self._place
Contributor:

It looks like this init_pir is called twice in a row?

Contributor Author (zhiqiu) commented Jul 19, 2024:

fix it in the next pr, thx.

@zhangbo9674 (Contributor) left a comment:

approved for print

@XieYunshen (Contributor) left a comment:

LGTM

@ForFishes (Member) left a comment:

LGTM

@XiaoguangHu01 (Contributor) left a comment:

LGTM

@zhiqiu merged commit ae0fe07 into PaddlePaddle:develop on Jul 22, 2024
31 checks passed
lixcli pushed a commit to lixcli/Paddle that referenced this pull request Jul 22, 2024
* [Test]Support pir amp for dist

* Refine code

* refine pir dist to_static

* fix bug

* fix partial

* Fix dist engine code

* fit pir grad_scaler with auto_parallel

* use amp strategy

* update ut

* update ut

* fit for amp o1

* revert changes of grad_scaler

* fix ut and refine code

---------

Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: winter-wang <1030748926@qq.com>