Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RankTaskGraph #9108

Closed
wants to merge 41 commits into from
Closed

RankTaskGraph #9108

wants to merge 41 commits into from

Conversation

lixinqi
Copy link
Contributor

@lixinqi lixinqi commented Sep 19, 2022

将TaskGraph的逻辑拆解成BoxingTaskGraph和RankTaskGraph。BoxingTaskGraph负责构建boxing相关的task graph子图,然后序列化到BoxingTaskGraphProto。RankTaskGraph负责两点:1)构建指定rank的CompTaskNode;2)从BoxingTaskGraphProto恢复属于boxing部分的子图;
分布式编译的大体过程将会是:

  1. 在main线程(或master进程)上由OpGraph构建BoxingTaskGraph,并序列化成BoxingTaskGraphProto;
  2. 在线程池里的各个worker线程(或worker进程)上由OpGraph/BoxingTaskGraphProto/rank等信息构建属于该rank的RankTaskGraph,然后生成该rank的plan。

本pr实现的是分离编译的中间状态版本:即BoxingTaskGraph在main线程上构建,而RankTaskGraph在线程池里构建。
后续pr再实现彻底的分离编译,即BoxingTaskGraph在master进程上构建,而RankTaskGraph在worker进程上构建。

@lixinqi lixinqi requested a review from strint as a code owner September 19, 2022 10:25
@lixinqi
Copy link
Contributor Author

lixinqi commented Sep 23, 2022

提供了环境变量供切换:

  1. ONEFLOW_LAZY_COMPILE_MODE=naive 旧版编译方式,全rank编译。
  2. ONEFLOW_LAZY_COMPILE_MODE=rank_per_thread 多线程分离编译,每个rank放在独立的线程里。
  3. ONEFLOW_LAZY_COMPILE_MODE=rank_per_iter 单线程分离编译,每个rank放在main线程的每次循环里。

如果多线程分离编译遇到bug,请回到单线程分离编译再跑一次。

@@ -709,12 +727,19 @@ void TaskGraph::EnableInplaceMemSharing(
const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable) {
ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
InplaceObasInfo safe_inplace_obas_info;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该逻辑已迁移到

void TaskGraph::EnableInplaceMemSharing(
    const HashSet<TaskNode*>& dev_nodes,
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable);

void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
const std::vector<CompTaskNode*>& dst_task_nodes) {
CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
std::string regst_desc_name;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该逻辑已移至

void TaskGraph::ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node);

@@ -94,9 +104,6 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
IsOpNameDataOrCtrlReachable) const;
void SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
const HashSet<TaskNode*>& dev_nodes) const;
void ForEachGpuDeviceNodes(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

该函数并没有删除,而是变成了public

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

@github-actions
Copy link
Contributor

Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.

std::vector<Plan> plans(GlobalProcessCtx::WorldSize());
JUST(OpGraph::WithSingleton(&job_, [&]() -> Maybe<void> {
Singleton<OpGraph>::Get()->UpdateCachedPredicatorIsReachable();
auto boxing_task_graph = JUST(BoxingTaskGraph::New());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BoxingTaskGraph相较于GlobalTaskGraph而言,它只包含骨干部分:1)boxing部分,所编译的TaskNode全是TransportTaskNode;2)与boxing相关的上下游ComputeTaskNode;

补集GlobalTaskGraph - BoxingTaskGraph里边的每一个TaskNode都是可以用Op + Sbp切割得到。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BoxingTaskGraph是所有rank的共识。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RankTaskGraph必须包含BoxingTaskGraph中本rank相关部分的根源是:

  1. 每个rank在编译时要独立推导shape, dtype。
  2. 每个rank在编译时要知晓上下游。

auto boxing_task_graph = JUST(BoxingTaskGraph::New());
// reachable collective boxing task pairs,
std::vector<HashSet<std::pair<int64_t /*src task_id*/, int64_t /*dst task_id*/>>>
reachable_cb_pairs{GlobalProcessCtx::WorldSize()};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

收集每个rank上 collective boxing task的可达关系。

bool IsComputTaskNodeDutyRank(int64_t current_rank, const ParallelDesc& parallel_desc,
int64_t task_node_rank) {
if (current_rank == 0) {
// make sure master knows at least one op_node.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

at least one compute task node.

reachable_cb_pairs{GlobalProcessCtx::WorldSize()};
Loop(GlobalProcessCtx::WorldSize(), [&](size_t i) {
auto boxing_task_graph_proto = std::make_shared<BoxingTaskGraphProto>();
auto PickTaskNode = [&]() -> std::function<bool(TaskNode*)> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TransportTaskNode

@@ -44,7 +46,7 @@ void AccCompTaskNode::BuildExecGphAndRegst() {
exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst);
out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));
exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);
exec_node->InferBlobDescs(parallel_ctx());
(exec_node->*InferBlobDescs())(parallel_ctx());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

成诚:名字应该改成GetExecNodeInferBlobDescsMethod()

sole_regst_desc = regst_desc;
});
auto* predefined = &regst_desc2predefined_regst_desc_id[sole_regst_desc];
*predefined = std::max(*predefined, comm_task_node->candidate_in_regst_desc_id());
Copy link
Contributor Author

@lixinqi lixinqi Nov 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处解决跨rank编译regst_desc_id同步问题。
首先,只有CopyCommNetTaskNode的input regst_desc_id才涉及到跨rank。我们仅需要对这里做处理就行。我们不需要对跨rank的ctrl regst desc id做处理的原因是它们已经保证了同步,它们在构建boxing task graph的过程中就会被构建,也就会随着boxing task graph的分发而同步到各处。

本次修复中,CopyCommNetTaskNode会新持有一个int64_t candidate_in_regst_desc_id字段,该字段会用于初始化CopyCommNetTaskNode上游TaskNode的produced in regst_desc_id。candidate_in_regst_desc_id字段名中带一个candidate字眼是因为上游TaskNode可能有多个下游的CopyCommNetTaskNode节点,所以上游TaskNode的produced regst_desc最后会选一个最大的candidate_in_regst_desc_id作为最终的regst_desc_id。

CopyCommNetTaskNode::candidate_in_regst_desc_id字段会随着boxing task graph的分发而全局同步,自然而然,同一个上游TaskNode的produced regst_desc_id也会由于max(candidate_in_regst_desc_id)的同步而得到同步。


} // namespace

std::unique_ptr<RegstDescIdProvider> NewRegstDescIdProvider() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RegstDesc类上不再直接持有int64_t regst_desc_id_,而是持有一个多态的RegstDescIdProvider regst_desc_id_provider_字段。它可以是ConstRegstDescIdProvider,用于对齐Naive情况,也可以是LazyInitRegstDescIdProvider,用于rank_per_iter/rank_per_thread/...等情况,其中regst_desc_id的设置会考虑producer task_node下游节点的情况。

@strint strint changed the base branch from master to rank_task_graph_test_passed December 13, 2022 10:12
@strint strint changed the base branch from rank_task_graph_test_passed to master December 13, 2022 10:59
@mergify mergify bot mentioned this pull request Dec 13, 2022
strint and others added 3 commits December 13, 2022 19:17
To be fixed distributed test.

---------

Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: cheng cheng <472491134@qq.com>
This was referenced Feb 28, 2023
@strint strint closed this Apr 13, 2023
@strint strint deleted the rank_task_graph branch April 13, 2023 08:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants