From 2eb532f6cafcf3c9032a0960c7bde1f0c0e51b42 Mon Sep 17 00:00:00 2001 From: linmin Date: Thu, 26 May 2016 16:44:59 +0800 Subject: [PATCH] move PostDFSVisit to graph_algorithm.h --- src/symbol/graph_algorithm.h | 31 +++++++++++++++++++++ src/symbol/graph_executor.cc | 23 ++++++++++++++-- src/symbol/static_graph.cc | 52 +++++++++++++----------------------- src/symbol/static_graph.h | 4 +-- src/symbol/symbol.cc | 43 +++++++++++++---------------- 5 files changed, 91 insertions(+), 62 deletions(-) diff --git a/src/symbol/graph_algorithm.h b/src/symbol/graph_algorithm.h index c009e28a..e301eb50 100644 --- a/src/symbol/graph_algorithm.h +++ b/src/symbol/graph_algorithm.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "./static_graph.h" namespace mxnet { @@ -112,6 +113,36 @@ inline uint32_t ColorNodeGroup( } return cindex + 1; } + +template +void PostOrderDFSVisit(const std::vector& heads, FVisit fvisit, + HashFunc hash, InDegree indegree, GetInput getinput) { + std::vector > stack; + std::unordered_set visited; + for (auto& head : heads) { + HashType head_hash = hash(head); + if (visited.count(head_hash) == 0) { + stack.push_back(std::make_pair(head, 0)); + visited.insert(head_hash); + } + while (!stack.empty()) { + std::pair& back = stack.back(); + if (back.second == indegree(back.first)) { + fvisit(back.first); + stack.pop_back(); + } else { + const GNode& input = getinput(back.first, back.second++); + HashType input_hash = hash(input); + if (visited.count(input_hash) == 0) { + stack.push_back(std::make_pair(input, 0)); + visited.insert(input_hash); + } + } + } + } +} + } // namespace graph } // namespace mxnet #endif // MXNET_SYMBOL_GRAPH_ALGORITHM_H_ diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index a3195d48..8662402a 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -324,7 +324,7 @@ void GraphExecutor::InitGraph(const Symbol &symbol, for (const auto& head : graph_.heads) { head_nodes.push_back(head.source_id); } - std::vector fwd_nodes = graph_.PostDFSOrder(head_nodes, std::unordered_set()); + std::vector fwd_nodes = graph_.PostDFSOrder(head_nodes); num_forward_nodes_ = fwd_nodes.size(); std::unordered_set fwd_set(fwd_nodes.begin(), fwd_nodes.end()); @@ -348,7 +348,26 @@ void GraphExecutor::InitGraph(const Symbol &symbol, } std::unordered_set finished(fwd_nodes.begin(), fwd_nodes.end()); for (uint32_t nid : backward) { - std::vector pass = graph_.PostDFSOrder({nid}, finished); + std::vector pass; + graph::PostOrderDFSVisit( + {nid}, + [&](uint32_t n) { if (finished.count(n) == 0) { + pass.push_back(n); + }}, // FVisit + [](uint32_t n)->uint32_t { return n; }, // HashFunc + [=](uint32_t n)->uint32_t { // InDegree + if (finished.count(n) == 1) { return 0; } + const StaticGraph::Node& node = graph_.nodes[n]; + return node.inputs.size() + static_cast(node.is_backward()); + }, + [=](uint32_t n, uint32_t index)->uint32_t { // GetInput + const StaticGraph::Node& node = graph_.nodes[n]; + if (index < node.inputs.size()) { + return node.inputs.at(index).source_id; + } else { + return node.backward_source_id; + } + }); topo_order_.insert(topo_order_.end(), pass.begin(), pass.end()); finished.insert(pass.begin(), pass.end()); } diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index bd772096..2e2e893c 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -9,44 +9,30 @@ #include #include #include "./static_graph.h" +#include "./graph_algorithm.h" #include "../operator/operator_common.h" namespace mxnet { -std::vector StaticGraph::PostDFSOrder(const std::vector& head_nodes, - const std::unordered_set& banned) const { +std::vector StaticGraph::PostDFSOrder( + const std::vector& head_nodes) const { std::vector ret; - std::unordered_set visited; ret.reserve(nodes.size() / 2); - std::vector > stack; - // heads - for (auto head : head_nodes) { - if (visited.count(head) != 0) continue; - stack.push_back(std::make_pair(head, 0)); - CHECK_EQ(banned.count(head), 0); - // bugfix - visited.insert(head); - while (!stack.empty()) { - std::pair& back = stack.back(); - const Node& n = nodes[back.first]; - if (back.second == n.inputs.size() + (n.is_backward() ? 1 : 0)) { - ret.push_back(back.first); - visited.insert(back.first); - stack.pop_back(); - } else { - uint32_t input; - if (back.second == n.inputs.size() && n.is_backward()) { - input = n.backward_source_id; - back.second++; + graph::PostOrderDFSVisit( + head_nodes, + [&](uint32_t n) { ret.push_back(n); }, // FVisit + [](uint32_t n)->uint32_t { return n; }, // HashFunc + [=](uint32_t n)->uint32_t { // InDegree + return nodes[n].inputs.size() + static_cast(nodes[n].is_backward()); + }, + [=](uint32_t n, uint32_t index)->uint32_t { // GetInput + const Node& node = nodes[n]; + if (index < node.inputs.size()) { + return node.inputs.at(index).source_id; } else { - input = n.inputs[back.second++].source_id; - } - if (visited.count(input) == 0 && banned.count(input) == 0) { - stack.push_back(std::make_pair(input, 0)); + return node.backward_source_id; } - } - } - } + }); return ret; } @@ -67,7 +53,7 @@ std::vector StaticGraph::TopoSort() const { head_nodes.push_back(static_cast(i)); } } - return PostDFSOrder(head_nodes, std::unordered_set()); + return PostDFSOrder(head_nodes); } bool StaticGraph::InferNodeShapes(const std::vector &topo_order, @@ -306,7 +292,7 @@ bool StaticGraph::InferShape(std::vector *in_shape, for (const auto& head : heads) { head_nodes.push_back(head.source_id); } - std::vector fwd_nodes = PostDFSOrder(head_nodes, std::unordered_set()); + std::vector fwd_nodes = PostDFSOrder(head_nodes); uint32_t counter = 0; for (uint32_t nid : fwd_nodes) { // backward consistentcy check. @@ -357,7 +343,7 @@ bool StaticGraph::InferType(std::vector *in_type, for (const auto& head : heads) { head_nodes.push_back(head.source_id); } - std::vector fwd_nodes = PostDFSOrder(head_nodes, std::unordered_set()); + std::vector fwd_nodes = PostDFSOrder(head_nodes); uint32_t counter = 0; for (uint32_t nid : fwd_nodes) { // backward consistentcy check. diff --git a/src/symbol/static_graph.h b/src/symbol/static_graph.h index 975c063c..fe5e206b 100644 --- a/src/symbol/static_graph.h +++ b/src/symbol/static_graph.h @@ -211,9 +211,7 @@ class StaticGraph { * \param banned The banned map, used to ban some nodes from the graph. * \return a post DFS visit order of nodes that can reach heads. */ - std::vector PostDFSOrder(const std::vector& head_nodes, - const std::unordered_set& banned - = std::unordered_set()) const; + std::vector PostDFSOrder(const std::vector& head_nodes) const; /*! * \brief infer the node shapes in the computation graph. * diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index ba11f663..e3abf8c0 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -10,6 +10,7 @@ #include #include #include "./static_graph.h" +#include "./graph_algorithm.h" namespace mxnet { @@ -83,31 +84,25 @@ inline bool Symbol::is_atomic() const { // implementation of template functions template inline void Symbol::DFSVisit(FVisit fvisit) const { - std::vector*, uint32_t> > stack; - std::unordered_set visited; - // put the head into the graph - for (auto &head : heads_) { - Node* ptr = head.source.get(); - if (visited.count(ptr) == 0) { - stack.push_back(std::make_pair(&head.source, 0)); - visited.insert(ptr); - } - while (!stack.empty()) { - std::pair *, uint32_t>& back = stack.back(); - if (back.second == back.first->get()->inputs.size()) { - fvisit(*(back.first)); - stack.pop_back(); - } else { - std::vector& inputs = back.first->get()->inputs; - Symbol::DataEntry& input = inputs.at(back.second++); - Node* ptr = input.source.get(); - if (visited.count(ptr) == 0) { - stack.push_back(std::make_pair(&input.source, 0)); - visited.insert(ptr); + typedef const std::shared_ptr* GNode; + std::vector head_nodes(heads_.size()); + std::transform(heads_.begin(), heads_.end(), head_nodes.begin(), + [](const DataEntry& e)->GNode { + return &e.source; + }); + graph::PostOrderDFSVisit( + head_nodes, + [fvisit](GNode n) { fvisit(*n); }, // FVisit + [](GNode n)->Node* { return n->get(); }, // HashFunc + [](GNode n)->uint32_t { return (*n)->inputs.size() + + static_cast((*n)->is_backward()); }, // InDegree + [](GNode n, uint32_t index)->GNode { // GetInput + if (index < (*n)->inputs.size()) { + return &(*n)->inputs.at(index).source; + } else { + return &(*n)->backward_source_node; } - } - } - } + }); } // helper function to handle keyword argument mismatch