RankTaskGraph #9108
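An environment variable is provided for switching between modes:
If multi-threaded separate compilation hits a bug, fall back to single-threaded separate compilation and run it again.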
@@ -709,12 +727,19 @@ void TaskGraph::EnableInplaceMemSharing(
    const std::function<bool(const std::string&, const std::string&)>&
        IsOpNameDataOrCtrlReachable) {
  ForEachGpuDeviceNodes([&](const HashSet<TaskNode*>& dev_nodes) {
    InplaceObasInfo safe_inplace_obas_info;
This logic has been moved to:
void TaskGraph::EnableInplaceMemSharing(
const HashSet<TaskNode*>& dev_nodes,
const std::function<bool(const std::string&, const std::string&)>&
IsOpNameDataOrCtrlReachable);
void TaskGraph::ConnectCtrlEdges(const std::vector<CompTaskNode*>& src_task_nodes,
                                 const std::vector<CompTaskNode*>& dst_task_nodes) {
  CHECK_EQ(src_task_nodes.size(), dst_task_nodes.size());
  FOR_RANGE(int32_t, i, 0, src_task_nodes.size()) {
    std::string regst_desc_name;
This logic has been moved to:
void TaskGraph::ConnectCtrlEdge(CompTaskNode* src_task_node, CompTaskNode* dst_task_node);
@@ -94,9 +104,6 @@ class TaskGraph final : public Graph<TaskNode, TaskEdge> {
          IsOpNameDataOrCtrlReachable) const;
  void SetTaskRegstInplaceInfo(const InplaceObasInfo& obas_info,
                               const HashSet<TaskNode*>& dev_nodes) const;
  void ForEachGpuDeviceNodes(
This function was not deleted; it was made public.
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
oneflow/core/framework/nn_graph.cpp
Outdated
std::vector<Plan> plans(GlobalProcessCtx::WorldSize());
JUST(OpGraph::WithSingleton(&job_, [&]() -> Maybe<void> {
  Singleton<OpGraph>::Get()->UpdateCachedPredicatorIsReachable();
  auto boxing_task_graph = JUST(BoxingTaskGraph::New());
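Compared with the GlobalTaskGraph, the BoxingTaskGraph contains only the backbone: 1) the boxing part, whose compiled TaskNodes are all TransportTaskNodes; 2) the upstream and downstream ComputeTaskNodes related to boxing.
Every TaskNode in the complement, GlobalTaskGraph - BoxingTaskGraph, can be obtained by partitioning with Op + Sbp.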
The BoxingTaskGraph is the consensus shared by all ranks.
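The fundamental reasons a RankTaskGraph must contain the parts of the BoxingTaskGraph that are relevant to its own rank are:
- Each rank must infer shape and dtype independently at compile time.
- Each rank must know its upstream and downstream nodes at compile time.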
auto boxing_task_graph = JUST(BoxingTaskGraph::New());
// reachable collective boxing task pairs,
std::vector<HashSet<std::pair<int64_t /*src task_id*/, int64_t /*dst task_id*/>>>
    reachable_cb_pairs{GlobalProcessCtx::WorldSize()};
This collects the reachability relations among the collective boxing tasks on each rank.
bool IsComputTaskNodeDutyRank(int64_t current_rank, const ParallelDesc& parallel_desc,
                              int64_t task_node_rank) {
  if (current_rank == 0) {
    // make sure master knows at least one op_node.
at least one compute task node.
    reachable_cb_pairs{GlobalProcessCtx::WorldSize()};
Loop(GlobalProcessCtx::WorldSize(), [&](size_t i) {
  auto boxing_task_graph_proto = std::make_shared<BoxingTaskGraphProto>();
  auto PickTaskNode = [&]() -> std::function<bool(TaskNode*)> {
TransportTaskNode
@@ -44,7 +46,7 @@ void AccCompTaskNode::BuildExecGphAndRegst() {
   exec_node->BindBnWithRegst(op()->SoleIbn(), in_regst);
   out_regst->AddLbi(op()->BnInOp2Lbi(op()->SoleObn()));
   exec_node->BindBnWithRegst(op()->SoleObn(), out_regst);
-  exec_node->InferBlobDescs(parallel_ctx());
+  (exec_node->*InferBlobDescs())(parallel_ctx());
Cheng Cheng: the name should be changed to GetExecNodeInferBlobDescsMethod().
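The call syntax in the hunk above, `(exec_node->*f())(args)`, is a call through a pointer to member function returned by a getter. A minimal sketch of that pattern follows, using the name suggested in the review. `MiniExecNode`, its two `Infer...` methods, the `use_cache` flag, and `BuildExecNode` are hypothetical stand-ins for illustration and are not taken from the OneFlow sources.

```cpp
#include <cassert>
#include <string>

// Hypothetical miniature of an exec node with two ways to infer blob descs.
class MiniExecNode {
 public:
  void InferFromInputs(int parallel_id) {
    last_call_ = "inputs:" + std::to_string(parallel_id);
  }
  void InferFromCache(int parallel_id) {
    last_call_ = "cache:" + std::to_string(parallel_id);
  }
  const std::string& last_call() const { return last_call_; }

 private:
  std::string last_call_;
};

// Pointer-to-member-function type shared by both inference methods.
using InferBlobDescsMethod = void (MiniExecNode::*)(int);

// The getter returns *which* method to call; the name makes that explicit,
// whereas a getter called InferBlobDescs() reads like the inference itself.
InferBlobDescsMethod GetExecNodeInferBlobDescsMethod(bool use_cache) {
  return use_cache ? &MiniExecNode::InferFromCache
                   : &MiniExecNode::InferFromInputs;
}

// Caller side: pick the method, then invoke it on the node with ->*.
void BuildExecNode(MiniExecNode* exec_node, bool use_cache, int parallel_id) {
  assert(exec_node != nullptr);
  (exec_node->*GetExecNodeInferBlobDescsMethod(use_cache))(parallel_id);
}
```

The extra parentheses around `exec_node->*...` are required because the function-call operator binds more tightly than `->*`.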
oneflow/core/job/rank_compiler.cpp
Outdated
  sole_regst_desc = regst_desc;
});
auto* predefined = &regst_desc2predefined_regst_desc_id[sole_regst_desc];
*predefined = std::max(*predefined, comm_task_node->candidate_in_regst_desc_id());
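This addresses the problem of keeping regst_desc_id in sync when each rank compiles separately.
First, only the input regst_desc_id of a CopyCommNetTaskNode crosses ranks, so that is the only place that needs handling. Cross-rank ctrl regst desc ids need no handling because they are already guaranteed to be in sync: they are created while the boxing task graph is built, so they are synchronized everywhere when the boxing task graph is distributed.
In this fix, CopyCommNetTaskNode holds a new int64_t candidate_in_regst_desc_id field, which is used to initialize the produced regst_desc_id of the TaskNode upstream of the CopyCommNetTaskNode. The field name contains the word "candidate" because an upstream TaskNode may have several downstream CopyCommNetTaskNodes, so the upstream TaskNode's produced regst_desc takes the largest candidate_in_regst_desc_id as its final regst_desc_id.
The CopyCommNetTaskNode::candidate_in_regst_desc_id field is synchronized globally when the boxing task graph is distributed. Consequently, the produced regst_desc_id of any given upstream TaskNode is synchronized as well, because max(candidate_in_regst_desc_id) evaluates to the same value on every rank.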
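The max-based selection can be sketched in isolation as follows. This is a simplified illustration under stated assumptions, not the code in rank_compiler.cpp: `CopyCommNetEdge`, `ResolveProducerRegstDescIds`, and the integer producer keys are hypothetical stand-ins for the real TaskNode and RegstDesc types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical stand-in: one CopyCommNetTaskNode consuming a producer's regst.
struct CopyCommNetEdge {
  int64_t producer_task_id;            // illustrative key for the upstream TaskNode
  int64_t candidate_in_regst_desc_id;  // assigned while the boxing task graph is built
};

// For each producer, take the largest candidate id among its downstream
// CopyCommNetTaskNodes. Every rank receives the same candidates with the
// boxing task graph, so every rank derives the same final id locally.
std::map<int64_t, int64_t> ResolveProducerRegstDescIds(
    const std::vector<CopyCommNetEdge>& edges) {
  std::map<int64_t, int64_t> producer2regst_desc_id;
  for (const auto& edge : edges) {
    assert(edge.candidate_in_regst_desc_id >= 0);
    auto result = producer2regst_desc_id.emplace(
        edge.producer_task_id, edge.candidate_in_regst_desc_id);
    if (!result.second) {
      // Producer already seen: keep the larger candidate.
      int64_t& current = result.first->second;
      current = std::max(current, edge.candidate_in_regst_desc_id);
    }
  }
  return producer2regst_desc_id;
}
```

Because max is commutative and associative, the result does not depend on the order in which a rank visits the CopyCommNetTaskNodes, which is what makes the id deterministic across ranks.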

}  // namespace

std::unique_ptr<RegstDescIdProvider> NewRegstDescIdProvider() {
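The RegstDesc class no longer holds an int64_t regst_desc_id_ directly; instead it holds a polymorphic RegstDescIdProvider field, regst_desc_id_provider_. The provider can be a ConstRegstDescIdProvider, which matches the Naive case, or a LazyInitRegstDescIdProvider, which is used for cases such as rank_per_iter/rank_per_thread/..., where setting the regst_desc_id takes the downstream nodes of the producer task_node into account.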
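A rough sketch of what such a provider hierarchy could look like is shown below. Only the three class names come from the comment above; the single virtual method, the callback-based lazy initialization, and the constructor signatures are assumptions made for illustration and may differ from the actual OneFlow interface.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>

// Assumed interface: one virtual accessor for the id.
class RegstDescIdProvider {
 public:
  virtual ~RegstDescIdProvider() = default;
  virtual int64_t regst_desc_id() = 0;
};

// The id is fixed at construction time (the Naive case).
class ConstRegstDescIdProvider final : public RegstDescIdProvider {
 public:
  explicit ConstRegstDescIdProvider(int64_t id) : id_(id) {}
  int64_t regst_desc_id() override { return id_; }

 private:
  const int64_t id_;
};

// The id is computed on first use, e.g. once the producer's downstream
// nodes are known. The callback is a hypothetical mechanism.
class LazyInitRegstDescIdProvider final : public RegstDescIdProvider {
 public:
  explicit LazyInitRegstDescIdProvider(std::function<int64_t()> init)
      : init_(std::move(init)) {
    assert(static_cast<bool>(init_));
  }
  int64_t regst_desc_id() override {
    if (!initialized_) {
      id_ = init_();  // evaluated at most once
      initialized_ = true;
    }
    return id_;
  }

 private:
  std::function<int64_t()> init_;
  bool initialized_ = false;
  int64_t id_ = -1;
};
```

Holding the provider behind a base-class pointer lets RegstDesc stay unaware of which compilation mode chose the id.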
18e3bac
to
3910af6
To be fixed distributed test. --------- Co-authored-by: lixinqi <lixinqi0703106@163.com> Co-authored-by: cheng cheng <472491134@qq.com>
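This PR splits the TaskGraph logic into BoxingTaskGraph and RankTaskGraph. BoxingTaskGraph builds the boxing-related subgraph of the task graph and serializes it to a BoxingTaskGraphProto. RankTaskGraph is responsible for two things: 1) building the CompTaskNodes of a given rank; 2) restoring the boxing subgraph from the BoxingTaskGraphProto.
The overall distributed compilation process will be:
This PR implements an intermediate version of separate compilation: the BoxingTaskGraph is built on the main thread, while the RankTaskGraphs are built in a thread pool.
A follow-up PR will implement fully separate compilation, in which the BoxingTaskGraph is built on the master process and the RankTaskGraphs are built on worker processes.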
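The division of labor between the two graphs can be sketched with toy data structures. Everything below is hypothetical and heavily simplified: the `...Stub` types, `OpStub::parallel_num`, and the rule that an op is placed on ranks `0..parallel_num-1` are assumptions for illustration only, not OneFlow's real placement or proto format.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

struct TaskNodeStub {
  int64_t rank;
  std::string name;
};

// Stand-in for BoxingTaskGraphProto: the boxing backbone shared by all ranks.
struct BoxingTaskGraphProtoStub {
  std::vector<TaskNodeStub> nodes;
};

// Stand-in for an op outside the boxing backbone (assumed placement rule:
// the op runs on ranks 0 .. parallel_num - 1).
struct OpStub {
  std::string name;
  int64_t parallel_num;
};

struct RankTaskGraphStub {
  int64_t rank;
  std::vector<TaskNodeStub> nodes;
};

// Builds one rank's graph: restore this rank's share of the boxing subgraph,
// then create the compute nodes that this rank is responsible for.
RankTaskGraphStub BuildRankTaskGraph(int64_t rank,
                                     const BoxingTaskGraphProtoStub& proto,
                                     const std::vector<OpStub>& ops) {
  assert(rank >= 0);
  RankTaskGraphStub graph{rank, {}};
  for (const auto& node : proto.nodes) {
    if (node.rank == rank) { graph.nodes.push_back(node); }
  }
  for (const auto& op : ops) {
    if (rank < op.parallel_num) {
      graph.nodes.push_back(TaskNodeStub{rank, "compute:" + op.name});
    }
  }
  return graph;
}
```

Since `BuildRankTaskGraph` only reads the shared proto and writes its own result, calls for different ranks are independent of each other, which is the property that allows them to be run in a thread pool now and on separate worker processes later.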