Skip to content

Commit

Permalink
b21 boxing add ctrl_edge (#4770)
Browse files Browse the repository at this point in the history
* b21 boxing add ctrl_edge

* refine

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: cheng cheng <472491134@qq.com>
  • Loading branch information
3 people authored Apr 28, 2021
1 parent 25f9f17 commit a9f70c7
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions oneflow/core/graph/boxing/b21_sub_task_graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,17 @@ Maybe<SubTskGphBuilderStatus> B21SubTskGphBuilder::Build(
const int64_t out_parallel_id = 0;
const int64_t nearest_in_parallel_id = SubTskGphBuilderUtil::FindNearestSrcParallelId(
in_parallel_desc, out_parallel_desc, out_parallel_id);
TaskNode* nearest_in_node = sorted_in_tasks.at(nearest_in_parallel_id);
TaskNode* proxy =
ctx->task_graph()->GetProxyNode(nearest_in_node, lbi, out_parallel_desc, out_parallel_id);
sorted_out_tasks->push_back(proxy);
sorted_ctrl_tasks->resize(1);
FOR_RANGE(int64_t, i, 0, in_parallel_desc.parallel_num()) {
TaskNode* in_node = sorted_in_tasks.at(i);
if (i == nearest_in_parallel_id) {
TaskNode* proxy =
ctx->task_graph()->GetProxyNode(in_node, lbi, out_parallel_desc, out_parallel_id);
sorted_out_tasks->push_back(proxy);
} else {
sorted_ctrl_tasks->at(0).push_back(in_node);
}
}
return TRY(BuildSubTskGphBuilderStatus("B21SubTskGphBuilder", ""));
} else {
return Error::BoxingNotSupportedError();
Expand Down

0 comments on commit a9f70c7

Please sign in to comment.