Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#11 from Fridge003/multi-down
Browse files Browse the repository at this point in the history
add reverse topo search algorithm for op fusion
  • Loading branch information
feifei-111 authored Apr 26, 2024
2 parents e39186c + 7d2e686 commit 18c5f51
Showing 1 changed file with 64 additions and 3 deletions.
67 changes: 64 additions & 3 deletions paddle/cinn/operator_fusion/graph_transformer/search_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,70 @@ struct SearchAlgorithm<ReverseTopoNodePairPattern,
Phrase,
GraphMatcher,
GraphOperation> {
// TODO(@wuzhanfei)
explicit SearchAlgorithm(PatternGraph<Phrase>* graph) {}
void operator()() {}
PatternGraph<Phrase>* graph_;
std::queue<PatternNodePtr<Phrase>> reverse_topo_nodes;

explicit SearchAlgorithm(PatternGraph<Phrase>* graph) {
VLOG(4) << "Create ReverseTopoNodePairPattern algorithm.";
graph_ = graph;

// Do reverse topological sort, and store the results in reverse_topo_nodes.
std::unordered_map<PatternNodePtr<Phrase>, int>
unvisited_nodes_to_out_degree;
for (const auto& node_ptr : graph->all_pattern_nodes()) {
unvisited_nodes_to_out_degree[node_ptr] = node_ptr->downstream().size();
}

while (!unvisited_nodes_to_out_degree.empty()) {
const auto& it =
std::find_if(unvisited_nodes_to_out_degree.begin(),
unvisited_nodes_to_out_degree.end(),
[&](const auto& pair) { return pair.second == 0; });
reverse_topo_nodes.push(it->first);
for (const auto& upstream : it->first->upstream()) {
--unvisited_nodes_to_out_degree[upstream];
}
unvisited_nodes_to_out_degree.erase(it);
}
}

std::optional<std::pair<PatternNodePtr<Phrase>, PatternNodePtr<Phrase>>>
FindMatchedPair() {
// Keep picking the front element of reverse_topo_nodes as candidate of
// upstream node. Please make sure that the downstream node is merged into
// the upstream node during merging, and the upstream node will not
// disappear after merging, else the logic here should be modified.
while (!reverse_topo_nodes.empty()) {
const auto& upstream_candidate = reverse_topo_nodes.front();

// If the node has downstream, try searching for its candidate downstream
// using GraphMatcher.
if (!upstream_candidate->downstream().empty()) {
for (const auto& downstream_candidate :
upstream_candidate->downstream()) {
if (GraphMatcher()(
*graph_, upstream_candidate, downstream_candidate)) {
VLOG(4) << "Find Matched Node Pair: (" << upstream_candidate << ", "
<< downstream_candidate << ")";
return std::make_pair(upstream_candidate, downstream_candidate);
}
}
}
reverse_topo_nodes.pop();
}

VLOG(4) << "Can't find matched node any more.";
return {};
}

void operator()() {
while (true) {
const auto& node = FindMatchedPair();
if (!node.has_value()) break;
const auto& [upstream, downstream] = node.value();
GraphOperation()(graph_, upstream, downstream);
}
}
};

template <typename Kind,
Expand Down

0 comments on commit 18c5f51

Please sign in to comment.