-
Notifications
You must be signed in to change notification settings - Fork 833
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Plan rank compiler #10141
Merged
Merged
Plan rank compiler #10141
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
62dee17
add task to/from proto
strint db84632
Update boxing_task_graph.proto
strint 49c8d18
Update task_edge.proto
strint 041a4b3
Update task_graph_rebuild_ctx.cpp
strint 3e08944
Update task_graph_rebuild_ctx.h
strint e0cf92b
Update transport_task_node.cpp
strint 008239e
support infer desc choose method
strint be2987d
refine comment
strint 9acbc79
rm useless
strint 34b3133
add comsume fake regst
strint acff92c
fix typo
strint 17203d0
add task factory to create new task node
strint 4377368
Merge branch 'sep0_task_proto' into sep2_custom_blobdesc_infer
strint 88f4297
add infer from ndsbp
strint 0e7b8ed
Merge branch 'sep2_custom_blobdesc_infer' into sep3_fake_regst
strint 266c388
rm useless
strint 9f952bd
Merge branch 'sep0_task_proto' into sep2_custom_blobdesc_infer
strint a9ad100
Merge branch 'sep2_custom_blobdesc_infer' into sep3_fake_regst
strint d7b7594
add rank compiler
strint 09154d0
add rank compiler
strint 410c71f
merge master
strint e9a20de
auto format by CI
oneflow-ci-bot 57499b7
fix merge
strint d169bdd
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint c956597
Merge branch 'master' into sep4_rank_task_graph
strint c2039df
fix licence
strint 09293c4
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint 4586502
move CreateOpAttributeRef
strint 45c1ca4
Merge branch 'master' into sep4_rank_task_graph
strint bded88f
fix NeedBoxing for NDSBP
strint af8303e
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint adb1b08
address review
strint 429aa14
Merge branch 'master' into sep4_rank_task_graph
strint a916407
fix static check
strint File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,37 @@ limitations under the License. | |
#include "oneflow/core/job/job_builder.h" | ||
#include "oneflow/core/job/local_sig_infer_hint.h" | ||
#include "oneflow/core/job/lazy_mode.h" | ||
#include "oneflow/core/common/container_util.h" | ||
#include "oneflow/core/persistence/tee_persistent_log_stream.h" | ||
#include "oneflow/core/auto_parallel/algorithm_util.h" | ||
#include "oneflow/core/framework/nd_sbp.h" | ||
#include "oneflow/core/framework/sbp_infer_util.h" | ||
|
||
namespace oneflow { | ||
|
||
bool OpEdge::NeedBoxing() const { | ||
if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; } | ||
if (src_node()->parallel_desc().parallel_num() == 1) { return false; } | ||
for (const auto& lbi : *lbis_) { | ||
Shape src_reduced_hierarchy; | ||
Shape dst_reduced_hierarchy; | ||
NdSbp src_reduced_nd_sbp; | ||
NdSbp dst_reduced_nd_sbp; | ||
|
||
InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(), | ||
*dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi), | ||
dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy, | ||
&dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp, | ||
src_node()->LogicalBlobDesc4Lbi(lbi).shape()); | ||
if (src_reduced_hierarchy != dst_reduced_hierarchy | ||
|| src_reduced_nd_sbp != dst_reduced_nd_sbp) { | ||
// Not one to one | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
std::string OpEdge::VisualStr() const { | ||
std::string str; | ||
int32_t idx = 0; | ||
|
@@ -54,12 +80,11 @@ const NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const { | |
return it->second; | ||
} | ||
|
||
OpNode::OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc, | ||
const OperatorConf& op_conf) | ||
OpNode::OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf) | ||
: parallel_desc_(parallel_desc), | ||
op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))), | ||
ibns_(op_->input_bns().begin(), op_->input_bns().end()) { | ||
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc)); | ||
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol())); | ||
} | ||
|
||
std::string OpNode::VisualStr() const { | ||
|
@@ -194,16 +219,14 @@ void OpGraph::CheckIsDAG() const { | |
|
||
namespace { | ||
|
||
std::function<std::shared_ptr<const ParallelDesc>(const std::string&)> | ||
MakeGetterParallelDesc4OpName(const Job& job) { | ||
std::function<Symbol<ParallelDesc>(const std::string&)> MakeGetterParallelDesc4OpName( | ||
const Job& job) { | ||
const Placement& placement = job.placement(); | ||
auto op_name2parallel_desc = | ||
std::make_shared<HashMap<std::string, std::shared_ptr<const ParallelDesc>>>(); | ||
auto op_name2parallel_desc = std::make_shared<HashMap<std::string, Symbol<ParallelDesc>>>(); | ||
op_name2parallel_desc->reserve(job.net().op_size()); | ||
for (const auto& placement_group : placement.placement_group()) { | ||
const ParallelConf& parallel_conf = placement_group.parallel_conf(); | ||
std::shared_ptr<const ParallelDesc> parallel_desc = | ||
std::make_shared<const ParallelDesc>(parallel_conf); | ||
Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf)); | ||
for (const std::string& op_name : placement_group.op_set().op_name()) { | ||
CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second) | ||
<< "op_name: " << op_name; | ||
|
@@ -566,6 +589,11 @@ Maybe<void> OpGraph::ForEachOpNode(const std::function<Maybe<void>(const OpNode& | |
return Maybe<void>::Ok(); | ||
} | ||
|
||
std::function<bool(const OpNode* src, const OpNode* dst)> OpGraph::CreatePredicatorIsReachable() | ||
const { | ||
return MakePredicatorIsReachable(); | ||
} | ||
|
||
// Print the graph with SBP in order | ||
void OpGraph::PrintSBPGraphDebugInfo() const { | ||
// test debug | ||
|
@@ -622,4 +650,17 @@ void OpGraph::PrintSBPGraphDebugInfo() const { | |
} | ||
} | ||
|
||
OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 提供一个 RAII 风格的 OpGraph |
||
// new Singleton<OpGraph> and set log configs. | ||
Singleton<OpGraph>::New(job); | ||
const JobDesc& job_desc = GlobalJobDesc(); | ||
if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) { | ||
TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(job); | ||
Singleton<OpGraph>::Get()->ToDotWithFilePath( | ||
"optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot"); | ||
} | ||
} | ||
|
||
OpGraphSingletonGuard::~OpGraphSingletonGuard() { Singleton<OpGraph>::Delete(); } | ||
|
||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个在什么情况下会是 nullptr 呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在之前的没有 Maybe 的基础上改的,估计是为了严谨。因为默认情况下,edge 初始化时的 src_node 和 dst_node 都是 nullptr