Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plan rank compiler #10141

Merged
merged 34 commits into from
Jun 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
62dee17
add task to/from proto
strint Apr 12, 2023
db84632
Update boxing_task_graph.proto
strint Apr 12, 2023
49c8d18
Update task_edge.proto
strint Apr 12, 2023
041a4b3
Update task_graph_rebuild_ctx.cpp
strint Apr 12, 2023
3e08944
Update task_graph_rebuild_ctx.h
strint Apr 12, 2023
e0cf92b
Update transport_task_node.cpp
strint Apr 12, 2023
008239e
support infer desc choose method
strint Apr 13, 2023
be2987d
refine comment
strint Apr 13, 2023
9acbc79
rm useless
strint Apr 14, 2023
34b3133
add comsume fake regst
strint Apr 14, 2023
acff92c
fix typo
strint Apr 14, 2023
17203d0
add task factory to create new task node
strint Apr 14, 2023
4377368
Merge branch 'sep0_task_proto' into sep2_custom_blobdesc_infer
strint Apr 14, 2023
88f4297
add infer from ndsbp
strint Apr 14, 2023
0e7b8ed
Merge branch 'sep2_custom_blobdesc_infer' into sep3_fake_regst
strint Apr 14, 2023
266c388
rm useless
strint Apr 14, 2023
9f952bd
Merge branch 'sep0_task_proto' into sep2_custom_blobdesc_infer
strint Apr 14, 2023
a9ad100
Merge branch 'sep2_custom_blobdesc_infer' into sep3_fake_regst
strint Apr 14, 2023
d7b7594
add rank compiler
strint Apr 14, 2023
09154d0
add rank compiler
strint Apr 15, 2023
410c71f
merge master
strint May 11, 2023
e9a20de
auto format by CI
oneflow-ci-bot May 11, 2023
57499b7
fix merge
strint May 12, 2023
d169bdd
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint May 12, 2023
c956597
Merge branch 'master' into sep4_rank_task_graph
strint May 12, 2023
c2039df
fix licence
strint May 12, 2023
09293c4
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint May 12, 2023
4586502
move CreateOpAttributeRef
strint May 20, 2023
45c1ca4
Merge branch 'master' into sep4_rank_task_graph
strint May 21, 2023
bded88f
fix NeedBoxing for NDSBP
strint May 24, 2023
af8303e
Merge branch 'sep4_rank_task_graph' of https://github.com/Oneflow-Inc…
strint May 24, 2023
adb1b08
address review
strint May 27, 2023
429aa14
Merge branch 'master' into sep4_rank_task_graph
strint Jun 3, 2023
a916407
fix static check
strint Jun 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions oneflow/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Graph {
std::function<Maybe<void>(NodeType*)> NodeHandler) const;
void ReverseTopoForEachNode(std::function<void(NodeType*)> NodeHandler) const;
void ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const;
Maybe<void> MaybeForEachEdge(std::function<Maybe<void>(EdgeType*)> EdgeHandler) const;

void SortedTopoForEachNode(std::function<bool(const EdgeType* lhs, const EdgeType* rhs)> LessThan,
std::function<void(NodeType*)> NodeHandler) const;
Expand Down Expand Up @@ -292,6 +293,16 @@ void Graph<NodeType, EdgeType>::ForEachEdge(std::function<void(EdgeType*)> EdgeH
}
}

template<typename NodeType, typename EdgeType>
Maybe<void> Graph<NodeType, EdgeType>::MaybeForEachEdge(
std::function<Maybe<void>(EdgeType*)> EdgeHandler) const {
for (auto& x : edges_) {
if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在什么情况下会是 nullptr 呢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个在什么情况下会是 nullptr 呢

在之前的没有 Maybe 的基础上改的,估计是为了严谨。因为默认情况下,edge 初始化时的 src_node 和 dst_node 都是 nullptr

template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachEdge(std::function<void(EdgeType*)> EdgeHandler) const {
  for (auto& x : edges_) {
    if (x->src_node() == nullptr && x->dst_node() == nullptr) { continue; }
    EdgeHandler(x.get());
  }
}

JUST(EdgeHandler(x.get()));
}
return Maybe<void>::Ok();
}

template<typename NodeType, typename EdgeType>
NodeType* Graph<NodeType, EdgeType>::SoleNode() const {
CHECK_EQ(nodes_.size(), 1);
Expand Down
59 changes: 50 additions & 9 deletions oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,37 @@ limitations under the License.
#include "oneflow/core/job/job_builder.h"
#include "oneflow/core/job/local_sig_infer_hint.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/persistence/tee_persistent_log_stream.h"
#include "oneflow/core/auto_parallel/algorithm_util.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/framework/sbp_infer_util.h"

namespace oneflow {

bool OpEdge::NeedBoxing() const {
if (src_node()->parallel_desc_sym() != dst_node()->parallel_desc_sym()) { return true; }
if (src_node()->parallel_desc().parallel_num() == 1) { return false; }
for (const auto& lbi : *lbis_) {
Shape src_reduced_hierarchy;
Shape dst_reduced_hierarchy;
NdSbp src_reduced_nd_sbp;
NdSbp dst_reduced_nd_sbp;

InOutParallelDimReduce(*src_node()->parallel_desc().hierarchy(),
*dst_node()->parallel_desc().hierarchy(), src_node()->NdSbp4Lbi(lbi),
dst_node()->NdSbp4Lbi(lbi), &src_reduced_hierarchy,
&dst_reduced_hierarchy, &src_reduced_nd_sbp, &dst_reduced_nd_sbp,
src_node()->LogicalBlobDesc4Lbi(lbi).shape());
if (src_reduced_hierarchy != dst_reduced_hierarchy
|| src_reduced_nd_sbp != dst_reduced_nd_sbp) {
// Not one to one
return true;
}
}
return false;
}

std::string OpEdge::VisualStr() const {
std::string str;
int32_t idx = 0;
Expand Down Expand Up @@ -54,12 +80,11 @@ const NdSbp& OpNode::NdSbp4Lbi(const LogicalBlobId& lbi) const {
return it->second;
}

OpNode::OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc,
const OperatorConf& op_conf)
OpNode::OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf)
: parallel_desc_(parallel_desc),
op_(CHECK_JUST(ConstructOp(op_conf, parallel_desc->device_type()))),
ibns_(op_->input_bns().begin(), op_->input_bns().end()) {
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc));
CHECK_JUST(op_->FillOpParallelDesc(parallel_desc.shared_from_symbol()));
}

std::string OpNode::VisualStr() const {
Expand Down Expand Up @@ -194,16 +219,14 @@ void OpGraph::CheckIsDAG() const {

namespace {

std::function<std::shared_ptr<const ParallelDesc>(const std::string&)>
MakeGetterParallelDesc4OpName(const Job& job) {
std::function<Symbol<ParallelDesc>(const std::string&)> MakeGetterParallelDesc4OpName(
const Job& job) {
const Placement& placement = job.placement();
auto op_name2parallel_desc =
std::make_shared<HashMap<std::string, std::shared_ptr<const ParallelDesc>>>();
auto op_name2parallel_desc = std::make_shared<HashMap<std::string, Symbol<ParallelDesc>>>();
op_name2parallel_desc->reserve(job.net().op_size());
for (const auto& placement_group : placement.placement_group()) {
const ParallelConf& parallel_conf = placement_group.parallel_conf();
std::shared_ptr<const ParallelDesc> parallel_desc =
std::make_shared<const ParallelDesc>(parallel_conf);
Symbol<ParallelDesc> parallel_desc = SymbolOf(ParallelDesc(parallel_conf));
for (const std::string& op_name : placement_group.op_set().op_name()) {
CHECK(op_name2parallel_desc->emplace(op_name, parallel_desc).second)
<< "op_name: " << op_name;
Expand Down Expand Up @@ -566,6 +589,11 @@ Maybe<void> OpGraph::ForEachOpNode(const std::function<Maybe<void>(const OpNode&
return Maybe<void>::Ok();
}

std::function<bool(const OpNode* src, const OpNode* dst)> OpGraph::CreatePredicatorIsReachable()
const {
return MakePredicatorIsReachable();
}

// Print the graph with SBP in order
void OpGraph::PrintSBPGraphDebugInfo() const {
// test debug
Expand Down Expand Up @@ -622,4 +650,17 @@ void OpGraph::PrintSBPGraphDebugInfo() const {
}
}

OpGraphSingletonGuard::OpGraphSingletonGuard(const Job& job) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

提供一个 RAII 风格的 OpGraph

// new Singleton<OpGraph> and set log configs.
Singleton<OpGraph>::New(job);
const JobDesc& job_desc = GlobalJobDesc();
if (Singleton<ResourceDesc, ForSession>::Get()->enable_debug_mode()) {
TeePersistentLogStream::Create(StrCat("optimized_job", job_desc.job_id()))->Write(job);
Singleton<OpGraph>::Get()->ToDotWithFilePath(
"optimized_dlnet_" + std::to_string(job_desc.job_id()) + "_op_graph.dot");
}
}

OpGraphSingletonGuard::~OpGraphSingletonGuard() { Singleton<OpGraph>::Delete(); }

} // namespace oneflow
16 changes: 13 additions & 3 deletions oneflow/core/graph/op_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ class OpGraph;
class OpNode final : public Node<OpNode, OpEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(OpNode);
explicit OpNode(const std::shared_ptr<const ParallelDesc>& parallel_desc,
const OperatorConf& op_conf);
explicit OpNode(Symbol<ParallelDesc> parallel_desc, const OperatorConf& op_conf);
~OpNode() = default;

// Getters
bool IsTimeShapeIdentity() const;
const Operator& op() const { return *op_; }
std::shared_ptr<const Operator> shared_op() const { return op_; }
const ParallelDesc& parallel_desc() const { return *parallel_desc_; }
Symbol<ParallelDesc> parallel_desc_sym() const { return parallel_desc_; }
const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); }
const NdSbpSignature& nd_sbp_signature() const { return *CHECK_JUST(op().nd_sbp_signature()); }
const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const;
Expand All @@ -67,7 +67,7 @@ class OpNode final : public Node<OpNode, OpEdge> {
void InitLbi2SourceNode();
void InitLbi2NdSbp();

std::shared_ptr<const ParallelDesc> parallel_desc_;
Symbol<ParallelDesc> parallel_desc_;
std::shared_ptr<Operator> op_;
HashSet<std::string> ibns_;
HashMap<LogicalBlobId, OpNode*> lbi2source_node_;
Expand All @@ -88,6 +88,8 @@ class OpEdge final : public Edge<OpNode, OpEdge> {
const std::vector<LogicalBlobId>& lbis() const { return *lbis_; }
const HashMap<LogicalBlobId, std::string>& lbi2obn() const { return *lbi2obn_; }
const HashMap<LogicalBlobId, std::vector<std::string>>& lbi2ibns() const { return *lbi2ibns_; }

bool NeedBoxing() const;
std::string VisualStr() const override;

private:
Expand Down Expand Up @@ -130,6 +132,7 @@ class OpGraph final : public Graph<OpNode, OpEdge> {

Maybe<void> Init(const Job& job);

std::function<bool(const OpNode* src, const OpNode* dst)> CreatePredicatorIsReachable() const;
// Print the graph with SBP in order
void PrintSBPGraphDebugInfo() const;

Expand All @@ -155,6 +158,13 @@ class OpGraph final : public Graph<OpNode, OpEdge> {
HashMap<std::string, HashSet<std::string>> producer_op_name2ctrl_consumer_op_names_;
};

class OpGraphSingletonGuard {
public:
OF_DISALLOW_COPY_AND_MOVE(OpGraphSingletonGuard);
explicit OpGraphSingletonGuard(const Job& job);
~OpGraphSingletonGuard();
};

} // namespace oneflow

#endif // ONEFLOW_CORE_GRAPH_OP_GRAPH_H_
Loading