
[Dist Dialect] Add MoE-related api in PIR dist dialect #66462


Merged: 5 commits merged into PaddlePaddle:develop on Jul 30, 2024

Conversation

pkuzyc
Contributor

@pkuzyc pkuzyc commented Jul 24, 2024

PR Category

Auto Parallel

PR Types

New features

Description

Pcard-67164

Add the corresponding dist_ops for the MoE APIs (#63904) in the PIR dist dialect, to be used in MoE models as shown below:
[figure: MoE model diagram]

  • local_tensors_from_dtensor: splits a dist tensor on the global mesh into a list of dist tensors on its sub-meshes, e.g. for a dist tensor with mesh=[0,1], placements=[Shard(0)], it returns the sub-mesh list [DistTensor(mesh=[0], placements=[Replicate()]), DistTensor(mesh=[1], placements=[Replicate()])].
  • dtensor_from_local_tensors: the inverse of local_tensors_from_dtensor, which builds the global-mesh dist tensor from sub-mesh dist tensors, e.g. for the sub-mesh list [DistTensor(mesh=[0], placements=[Replicate()]), DistTensor(mesh=[1], placements=[Replicate()])], it returns the global-mesh dist tensor with mesh=[0,1], placements=[Shard(0)] (see the sketch below).
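For intuition, here is a minimal single-process sketch of the semantics described above, assuming mesh=[0,1] and placements=[Shard(0)]. It uses plain paddle tensors and split/concat to stand in for DistTensors and the two new dist_ops, so it runs without a distributed launch; it illustrates the data layout only and is not the API added in this PR.

```python
import paddle

# Stand-in for a dist tensor on mesh=[0, 1] with placements=[Shard(0)]:
# dim 0 is split across the two ranks of the global mesh.
global_tensor = paddle.arange(16, dtype="float32").reshape([8, 2])

# local_tensors_from_dtensor: one tensor per sub-mesh ([0] and [1]), each
# fully replicated on its own single-rank sub-mesh; with Shard(0) this is
# modeled here as a plain split along dim 0.
sub_mesh_tensors = paddle.split(global_tensor, num_or_sections=2, axis=0)

# ... per-expert (per-sub-mesh) computation would happen here in an MoE model ...

# dtensor_from_local_tensors: the inverse op, rebuilding the global-mesh
# dist tensor with placements=[Shard(0)] from the sub-mesh tensors,
# modeled here as a concat along dim 0.
rebuilt = paddle.concat(sub_mesh_tensors, axis=0)

assert bool(paddle.allclose(global_tensor, rebuilt))
```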


paddle-bot bot commented Jul 24, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@@ -40,5 +40,24 @@ pir::Value reshard(
pir::Value reshard(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr);

std::vector<pir::Value> local_tensors_from_dist(
Contributor


Are these two ops kept only at the IR representation level, or are they also handed to the execution layer for the executor to run?
The dist2dense pass removes all representation-level information in the dist dialect and gives the executor a pure local dense program; will it remove these two ops?

Contributor Author


Only at the representation level; at execution time, share_data is used to reuse the current rank's local data. The next PR will add the replacement logic in remove_other_rank_op_pass.

Contributor


This replacement logic should not be implemented in remove_other_rank_op_pass; it seems more reasonable to implement it in the reshard pass, which parses a representation-level reshard op and replaces it with the actual collective operation ops.

Contributor Author


dtensor_from_local_tensors may also need a reshard, so the replacement was put in remove_other_rank_op_pass; we can also try putting it at the end of the reshard pass.

Contributor

@zhiqiu zhiqiu left a comment


LGTM

Contributor

@xiaoguoguo626807 xiaoguoguo626807 left a comment


LGTM for backward

@zhiqiu zhiqiu merged commit 8718d78 into PaddlePaddle:develop Jul 30, 2024
30 of 31 checks passed
Lans1ot pushed a commit to Lans1ot/Paddle that referenced this pull request Aug 5, 2024
…66462)

* add two MoE api in distributed dialect

* polish the dist_op and add unit test

* remove simple_net_ep unit test

* remove redundant print

* bug fix, replace platform::errors with phi::errors