Skip to content

Commit

Permalink
Merge pull request #2252 from mavenlin/dfs
Browse files Browse the repository at this point in the history
move PostDFSVisit to graph_algorithm.h
  • Loading branch information
antinucleon committed May 26, 2016
2 parents 6fbad56 + 2eb532f commit cbe23ac
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 62 deletions.
31 changes: 31 additions & 0 deletions src/symbol/graph_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <dmlc/logging.h>
#include <mxnet/symbolic.h>
#include <vector>
#include <utility>
#include "./static_graph.h"

namespace mxnet {
Expand Down Expand Up @@ -112,6 +113,36 @@ inline uint32_t ColorNodeGroup(
}
return cindex + 1;
}

template <typename GNode, typename HashType, typename FVisit,
typename HashFunc, typename InDegree, typename GetInput>
void PostOrderDFSVisit(const std::vector<GNode>& heads, FVisit fvisit,
HashFunc hash, InDegree indegree, GetInput getinput) {
std::vector<std::pair<GNode, uint32_t> > stack;
std::unordered_set<HashType> visited;
for (auto& head : heads) {
HashType head_hash = hash(head);
if (visited.count(head_hash) == 0) {
stack.push_back(std::make_pair(head, 0));
visited.insert(head_hash);
}
while (!stack.empty()) {
std::pair<GNode, uint32_t>& back = stack.back();
if (back.second == indegree(back.first)) {
fvisit(back.first);
stack.pop_back();
} else {
const GNode& input = getinput(back.first, back.second++);
HashType input_hash = hash(input);
if (visited.count(input_hash) == 0) {
stack.push_back(std::make_pair(input, 0));
visited.insert(input_hash);
}
}
}
}
}

} // namespace graph
} // namespace mxnet
#endif // MXNET_SYMBOL_GRAPH_ALGORITHM_H_
23 changes: 21 additions & 2 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ void GraphExecutor::InitGraph(const Symbol &symbol,
for (const auto& head : graph_.heads) {
head_nodes.push_back(head.source_id);
}
std::vector<uint32_t> fwd_nodes = graph_.PostDFSOrder(head_nodes, std::unordered_set<uint32_t>());
std::vector<uint32_t> fwd_nodes = graph_.PostDFSOrder(head_nodes);
num_forward_nodes_ = fwd_nodes.size();

std::unordered_set<uint32_t> fwd_set(fwd_nodes.begin(), fwd_nodes.end());
Expand All @@ -348,7 +348,26 @@ void GraphExecutor::InitGraph(const Symbol &symbol,
}
std::unordered_set<uint32_t> finished(fwd_nodes.begin(), fwd_nodes.end());
for (uint32_t nid : backward) {
std::vector<uint32_t> pass = graph_.PostDFSOrder({nid}, finished);
std::vector<uint32_t> pass;
graph::PostOrderDFSVisit<uint32_t, uint32_t>(
{nid},
[&](uint32_t n) { if (finished.count(n) == 0) {
pass.push_back(n);
}}, // FVisit
[](uint32_t n)->uint32_t { return n; }, // HashFunc
[=](uint32_t n)->uint32_t { // InDegree
if (finished.count(n) == 1) { return 0; }
const StaticGraph::Node& node = graph_.nodes[n];
return node.inputs.size() + static_cast<uint32_t>(node.is_backward());
},
[=](uint32_t n, uint32_t index)->uint32_t { // GetInput
const StaticGraph::Node& node = graph_.nodes[n];
if (index < node.inputs.size()) {
return node.inputs.at(index).source_id;
} else {
return node.backward_source_id;
}
});
topo_order_.insert(topo_order_.end(), pass.begin(), pass.end());
finished.insert(pass.begin(), pass.end());
}
Expand Down
52 changes: 19 additions & 33 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,30 @@
#include <queue>
#include <map>
#include "./static_graph.h"
#include "./graph_algorithm.h"
#include "../operator/operator_common.h"

namespace mxnet {

std::vector<uint32_t> StaticGraph::PostDFSOrder(const std::vector<uint32_t>& head_nodes,
const std::unordered_set<uint32_t>& banned) const {
std::vector<uint32_t> StaticGraph::PostDFSOrder(
const std::vector<uint32_t>& head_nodes) const {
std::vector<uint32_t> ret;
std::unordered_set<uint32_t> visited;
ret.reserve(nodes.size() / 2);
std::vector<std::pair<uint32_t, uint32_t> > stack;
// heads
for (auto head : head_nodes) {
if (visited.count(head) != 0) continue;
stack.push_back(std::make_pair(head, 0));
CHECK_EQ(banned.count(head), 0);
// bugfix
visited.insert(head);
while (!stack.empty()) {
std::pair<uint32_t, uint32_t>& back = stack.back();
const Node& n = nodes[back.first];
if (back.second == n.inputs.size() + (n.is_backward() ? 1 : 0)) {
ret.push_back(back.first);
visited.insert(back.first);
stack.pop_back();
} else {
uint32_t input;
if (back.second == n.inputs.size() && n.is_backward()) {
input = n.backward_source_id;
back.second++;
graph::PostOrderDFSVisit<uint32_t, uint32_t>(
head_nodes,
[&](uint32_t n) { ret.push_back(n); }, // FVisit
[](uint32_t n)->uint32_t { return n; }, // HashFunc
[=](uint32_t n)->uint32_t { // InDegree
return nodes[n].inputs.size() + static_cast<uint32_t>(nodes[n].is_backward());
},
[=](uint32_t n, uint32_t index)->uint32_t { // GetInput
const Node& node = nodes[n];
if (index < node.inputs.size()) {
return node.inputs.at(index).source_id;
} else {
input = n.inputs[back.second++].source_id;
}
if (visited.count(input) == 0 && banned.count(input) == 0) {
stack.push_back(std::make_pair(input, 0));
return node.backward_source_id;
}
}
}
}
});
return ret;
}

Expand All @@ -67,7 +53,7 @@ std::vector<uint32_t> StaticGraph::TopoSort() const {
head_nodes.push_back(static_cast<uint32_t>(i));
}
}
return PostDFSOrder(head_nodes, std::unordered_set<uint32_t>());
return PostDFSOrder(head_nodes);
}

bool StaticGraph::InferNodeShapes(const std::vector<uint32_t> &topo_order,
Expand Down Expand Up @@ -306,7 +292,7 @@ bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
for (const auto& head : heads) {
head_nodes.push_back(head.source_id);
}
std::vector<uint32_t> fwd_nodes = PostDFSOrder(head_nodes, std::unordered_set<uint32_t>());
std::vector<uint32_t> fwd_nodes = PostDFSOrder(head_nodes);
uint32_t counter = 0;
for (uint32_t nid : fwd_nodes) {
// backward consistentcy check.
Expand Down Expand Up @@ -357,7 +343,7 @@ bool StaticGraph::InferType(std::vector<int> *in_type,
for (const auto& head : heads) {
head_nodes.push_back(head.source_id);
}
std::vector<uint32_t> fwd_nodes = PostDFSOrder(head_nodes, std::unordered_set<uint32_t>());
std::vector<uint32_t> fwd_nodes = PostDFSOrder(head_nodes);
uint32_t counter = 0;
for (uint32_t nid : fwd_nodes) {
// backward consistentcy check.
Expand Down
4 changes: 1 addition & 3 deletions src/symbol/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,7 @@ class StaticGraph {
* \param banned The banned map, used to ban some nodes from the graph.
* \return a post DFS visit order of nodes that can reach heads.
*/
std::vector<uint32_t> PostDFSOrder(const std::vector<uint32_t>& head_nodes,
const std::unordered_set<uint32_t>& banned
= std::unordered_set<uint32_t>()) const;
std::vector<uint32_t> PostDFSOrder(const std::vector<uint32_t>& head_nodes) const;
/*!
* \brief infer the node shapes in the computation graph.
*
Expand Down
43 changes: 19 additions & 24 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <unordered_map>
#include <unordered_set>
#include "./static_graph.h"
#include "./graph_algorithm.h"

namespace mxnet {

Expand Down Expand Up @@ -83,31 +84,25 @@ inline bool Symbol::is_atomic() const {
// implementation of template functions
template<typename FVisit>
inline void Symbol::DFSVisit(FVisit fvisit) const {
std::vector<std::pair<const std::shared_ptr<Node>*, uint32_t> > stack;
std::unordered_set<Node*> visited;
// put the head into the graph
for (auto &head : heads_) {
Node* ptr = head.source.get();
if (visited.count(ptr) == 0) {
stack.push_back(std::make_pair(&head.source, 0));
visited.insert(ptr);
}
while (!stack.empty()) {
std::pair<const std::shared_ptr<Node> *, uint32_t>& back = stack.back();
if (back.second == back.first->get()->inputs.size()) {
fvisit(*(back.first));
stack.pop_back();
} else {
std::vector<Symbol::DataEntry>& inputs = back.first->get()->inputs;
Symbol::DataEntry& input = inputs.at(back.second++);
Node* ptr = input.source.get();
if (visited.count(ptr) == 0) {
stack.push_back(std::make_pair(&input.source, 0));
visited.insert(ptr);
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(heads_.size());
std::transform(heads_.begin(), heads_.end(), head_nodes.begin(),
[](const DataEntry& e)->GNode {
return &e.source;
});
graph::PostOrderDFSVisit<GNode, Node*>(
head_nodes,
[fvisit](GNode n) { fvisit(*n); }, // FVisit
[](GNode n)->Node* { return n->get(); }, // HashFunc
[](GNode n)->uint32_t { return (*n)->inputs.size() +
static_cast<int>((*n)->is_backward()); }, // InDegree
[](GNode n, uint32_t index)->GNode { // GetInput
if (index < (*n)->inputs.size()) {
return &(*n)->inputs.at(index).source;
} else {
return &(*n)->backward_source_node;
}
}
}
}
});
}

// helper function to handle keyword argument mismatch
Expand Down

0 comments on commit cbe23ac

Please sign in to comment.