[Dist Dialect] Add MoE-related api in PIR dist dialect #66462

Conversation
Your PR has been submitted successfully. Thank you for your contribution to this open-source project!
@@ -40,5 +40,24 @@ pir::Value reshard(
pir::Value reshard(const pir::Value& x,
                   const TensorDistAttribute& tensor_dist_attr);

std::vector<pir::Value> local_tensors_from_dist(
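The hunk is cut off at the new declaration. A plausible completion, modelled on the adjacent reshard() signature, might look like the sketch below; the parameter lists and the second function's name are assumptions, not copied from the PR:

```cpp
// Hypothetical completion; parameter names and types are assumptions
// modelled on the existing reshard() API in this header, not the PR diff.
std::vector<pir::Value> local_tensors_from_dist(
    const pir::Value& input,
    const std::vector<TensorDistAttribute>& local_dist_attrs);

// Presumed inverse declaration, rebuilding the global-mesh dist tensor
// from per-sub-mesh values (name and signature are likewise assumptions).
pir::Value dist_tensor_from_locals(
    const std::vector<pir::Value>& local_tensors,
    const TensorDistAttribute& global_dist_attr);
```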
Are these two ops only at the IR representation level, or are they also handed down to the execution layer for the executor to run?
The dist2dense pass removes all representation-level information in the dist dialect and hands the executor a purely local dense program; will it remove these two ops?
They exist only at the representation level; at execution time, share_data is used to consume the current rank's local data. The next PR will add the replacement logic in remove_other_rank_op_pass.
This replacement logic should not be implemented in remove_other_rank_op_pass; it seems more reasonable to implement it in the reshard pass, which would parse a representation-level reshard op and replace it with the actual collective ops.
dtensor_from_local_tensors may also need resharding, which is why this was put in remove_other_rank_op_pass; we could also try placing it at the end of the reshard pass.
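To make the lowering being discussed concrete, here is a toy sketch of the per-rank replacement; every type and name below is a hypothetical stand-in (the real pass would rewrite pir operations, and share_data is the aliasing op mentioned above):

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a sub-mesh tensor produced by
// local_tensors_from_dtensor; the real pass sees pir::Value instead.
struct LocalTensor {
  std::vector<int64_t> mesh;  // ranks this sub-mesh tensor lives on
  int data_id;                // placeholder for the underlying buffer
};

// Toy lowering: on rank cur_rank, the representation-level op degenerates
// to share_data-style aliasing of the sub-mesh tensor owned by that rank;
// tensors of other ranks get dropped (as remove_other_rank_op_pass does).
LocalTensor LowerForRank(const std::vector<LocalTensor>& sub_mesh_tensors,
                         int64_t cur_rank) {
  for (const auto& t : sub_mesh_tensors) {
    for (int64_t r : t.mesh) {
      if (r == cur_rank) return t;  // alias the local data, no copy
    }
  }
  throw std::runtime_error("current rank owns no sub-mesh tensor");
}
```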
LGTM
LGTM for backward
…66462)

* add two MoE api in distributed dialect
* polish the dist_op and add unit test
* remove simple_net_ep unit test
* remove redundant print
* bug fix, replace platform::errors with phi::errors
PR Category
Auto Parallel
PR Types
New features
Description
Pcard-67164
Add the corresponding dist_ops of the MoE API (#63904) to the PIR dist dialect, to be used in MoE models as follows:

- `local_tensors_from_dtensor`: get the list of sub-mesh dist tensors from a dist tensor on the global mesh, e.g. for a dist tensor with mesh=[0,1], placements=[Shard(0)], get its sub-mesh list [DistTensor(mesh=[0], placements=[Replicate()]), DistTensor(mesh=[1], placements=[Replicate()])].
- `dtensor_from_local_tensors`: the inverse of `local_tensors_from_dtensor`; get the global-mesh dist tensor from a list of sub-mesh dist tensors, e.g. for the sub-mesh list [DistTensor(mesh=[0], placements=[Replicate()]), DistTensor(mesh=[1], placements=[Replicate()])], get the global-mesh dist tensor with mesh=[0,1], placements=[Shard(0)] (see the toy sketch below).
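A minimal sketch of these semantics, using a toy DistTensor stand-in rather than Paddle's real phi::distributed::DistTensor (all names and fields here are illustrative assumptions):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Toy model of a dist tensor: just a mesh and placements, no data.
struct DistTensor {
  std::vector<int64_t> mesh;            // process ids in the mesh
  std::vector<std::string> placements;  // e.g. "Shard(0)", "Replicate()"
};

// local_tensors_from_dtensor: split a global-mesh dist tensor into one
// replicated dist tensor per sub-mesh (one rank per sub-mesh here).
std::vector<DistTensor> LocalTensorsFromDtensor(const DistTensor& global) {
  std::vector<DistTensor> locals;
  for (int64_t rank : global.mesh) {
    locals.push_back({{rank}, {"Replicate()"}});
  }
  return locals;
}

// dtensor_from_local_tensors: the inverse, merging sub-mesh tensors back
// into one dist tensor on the union mesh with the requested placements.
DistTensor DtensorFromLocalTensors(const std::vector<DistTensor>& locals,
                                   std::vector<std::string> placements) {
  DistTensor global{{}, std::move(placements)};
  for (const auto& t : locals) {
    global.mesh.insert(global.mesh.end(), t.mesh.begin(), t.mesh.end());
  }
  return global;
}

int main() {
  DistTensor x{{0, 1}, {"Shard(0)"}};        // mesh=[0,1], Shard(0)
  auto locals = LocalTensorsFromDtensor(x);  // two sub-mesh tensors
  auto y = DtensorFromLocalTensors(locals, {"Shard(0)"});  // round trip
  std::cout << "sub-meshes: " << locals.size()
            << ", merged mesh size: " << y.mesh.size() << "\n";
  return 0;
}
```

Running this prints `sub-meshes: 2, merged mesh size: 2`, matching the round-trip example in the list above.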