From a9f70c799acda2086978cc3e225b4753b35ec52d Mon Sep 17 00:00:00 2001 From: guo ran <360112263@qq.com> Date: Wed, 28 Apr 2021 21:33:20 +0800 Subject: [PATCH] b21 boxing add ctrl_edge (#4770) * b21 boxing add ctrl_edge * refine Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> Co-authored-by: cheng cheng <472491134@qq.com> --- .../graph/boxing/b21_sub_task_graph_builder.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp index 1cc3eefcde9..460fec5269d 100644 --- a/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp +++ b/oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp @@ -30,10 +30,17 @@ Maybe B21SubTskGphBuilder::Build( const int64_t out_parallel_id = 0; const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId( in_parallel_desc, out_parallel_desc, out_parallel_id); - TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id); - TaskNode* proxy = - ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_parallel_id); - sorted_out_tasks->push_back(proxy); + sorted_ctrl_tasks->resize(1); + FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) { + TaskNode* in_node = sorted_in_tasks.at(i); + if (i == nearest_in_parallel_id) { + TaskNode* proxy = + ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, out_parallel_id); + sorted_out_tasks->push_back(proxy); + } else { + sorted_ctrl_tasks->at(0).push_back(in_node); + } + } return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", "")); } else { return Error::BoxingNotSupportedError();