|
| 1 | +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/fluid/framework/ir/generate_pass.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace framework { |
| 19 | +namespace ir { |
| 20 | + |
| 21 | +void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { |
| 22 | + const proto::BlockDesc& block = pass_desc.pattern().blocks(0); |
| 23 | + // Traverse all operators to create subgraph. |
| 24 | + for (int index = 0; index < block.ops_size(); ++index) { |
| 25 | + const proto::OpDesc& op = block.ops(index); |
| 26 | + // Create a PDNode for current operator. Use the index as name to avoid |
| 27 | + // multiple operators with same type. Get a PDNode from pattern subgraph |
| 28 | + // through index in rewrite phase. |
| 29 | + PDNode* op_pdnode = |
| 30 | + pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); |
| 31 | + // Create PDNodes for inputs of current operator. |
| 32 | + for (const proto::OpDesc::Var& var : op.inputs()) { |
| 33 | + for (const std::string& argument : var.arguments()) { |
| 34 | + // The input may be the output of other operator. |
| 35 | + PDNode* var_pdnode = pattern->RetrieveNode(argument); |
| 36 | + if (nullptr == var_pdnode) { |
| 37 | + var_pdnode = pattern->NewNode(argument)->AsInput(); |
| 38 | + } else if (var_pdnode->IsOutput()) { |
| 39 | + var_pdnode->AsIntermediate(); |
| 40 | + } |
| 41 | + var_pdnode->assert_is_op_input(op.type()); |
| 42 | + pattern->AddEdge(var_pdnode, op_pdnode); |
| 43 | + } |
| 44 | + } |
| 45 | + // Create PDNodes for outputs of current operator. |
| 46 | + for (const proto::OpDesc::Var& var : op.outputs()) { |
| 47 | + for (const std::string& argument : var.arguments()) { |
| 48 | + // The output may be the input of other operator. |
| 49 | + PDNode* var_pdnode = pattern->RetrieveNode(argument); |
| 50 | + if (nullptr == var_pdnode) { |
| 51 | + var_pdnode = pattern->NewNode(argument)->AsOutput(); |
| 52 | + } else if (var_pdnode->IsInput()) { |
| 53 | + var_pdnode->AsIntermediate(); |
| 54 | + } |
| 55 | + var_pdnode->assert_is_op_output(op.type()); |
| 56 | + pattern->AddEdge(op_pdnode, var_pdnode); |
| 57 | + } |
| 58 | + } |
| 59 | + // Set attribute condition for current operator. |
| 60 | + for (const proto::OpDesc::Attr& attr : op.attrs()) { |
| 61 | + op_pdnode->assert_more([&](Node* x) { |
| 62 | + if (x && x->IsOp()) { |
| 63 | + OpDesc* op_desc = x->Op(); |
| 64 | + if (op_desc->HasAttr(attr.name())) { |
| 65 | + return GetAttrValue(attr) == op_desc->GetAttr(attr.name()); |
| 66 | + } |
| 67 | + return false; |
| 68 | + } |
| 69 | + return false; |
| 70 | + }); |
| 71 | + } |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +GraphPatternDetector::handle_t GetGenerateRewrite( |
| 76 | + const PDPattern& pattern, const proto::PassDesc& pass_desc) { |
| 77 | + GraphPatternDetector::handle_t handler = [&]( |
| 78 | + const GraphPatternDetector::subgraph_t subgraph, Graph* graph) { |
| 79 | + // There are some duplicate patterns. |
| 80 | + for (auto iter : subgraph) { |
| 81 | + if (nullptr == graph->RetrieveNode(iter.second->id())) { |
| 82 | + VLOG(3) << "Node [" << iter.second->Name() |
| 83 | + << "] of subgraph has been removed. So skip this optimize."; |
| 84 | + return; |
| 85 | + } |
| 86 | + } |
| 87 | + const proto::BlockDesc& block = pass_desc.replace().blocks(0); |
| 88 | + // `var_node_maps` record the mapping of variable to the pattern subgraph. |
| 89 | + std::map<std::string, Node*> var_node_maps; |
| 90 | + for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { |
| 91 | + Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var())); |
| 92 | + var_node_maps.insert({var_map.replace_var(), node}); |
| 93 | + } |
| 94 | + // Traverse all operators to create subgraph. |
| 95 | + for (const proto::OpDesc& op : block.ops()) { |
| 96 | + OpDesc op_desc; |
| 97 | + std::vector<Node *> in_nodes, out_nodes; |
| 98 | + op_desc.SetType(op.type()); |
| 99 | + // Create Nodes for inputs of current operator. |
| 100 | + for (const proto::OpDesc::Var& var : op.inputs()) { |
| 101 | + std::vector<std::string> arguments; |
| 102 | + for (const std::string& argument : var.arguments()) { |
| 103 | + // The input may be mapped on the operator of pattern subgraph. |
| 104 | + Node* node = nullptr; |
| 105 | + auto iter = var_node_maps.find(argument); |
| 106 | + if (var_node_maps.end() == iter) { |
| 107 | + VarDesc var_desc(patterns::UniqueKey(argument)); |
| 108 | + node = graph->CreateVarNode(&var_desc); |
| 109 | + var_node_maps.insert({argument, node}); |
| 110 | + } else { |
| 111 | + node = iter->second; |
| 112 | + } |
| 113 | + in_nodes.push_back(node); |
| 114 | + arguments.push_back(node->Name()); |
| 115 | + } |
| 116 | + op_desc.SetInput(var.parameter(), arguments); |
| 117 | + } |
| 118 | + // Create Nodes for outputs of current operator. |
| 119 | + for (const proto::OpDesc::Var& var : op.outputs()) { |
| 120 | + std::vector<std::string> arguments; |
| 121 | + for (const std::string& argument : var.arguments()) { |
| 122 | + // The output may be mapped on the operator of pattern subgraph. |
| 123 | + Node* node = nullptr; |
| 124 | + auto iter = var_node_maps.find(argument); |
| 125 | + if (var_node_maps.end() == iter) { |
| 126 | + VarDesc var_desc(patterns::UniqueKey(argument)); |
| 127 | + node = graph->CreateVarNode(&var_desc); |
| 128 | + var_node_maps.insert({argument, node}); |
| 129 | + } else { |
| 130 | + node = iter->second; |
| 131 | + } |
| 132 | + out_nodes.push_back(node); |
| 133 | + arguments.push_back(node->Name()); |
| 134 | + } |
| 135 | + op_desc.SetOutput(var.parameter(), arguments); |
| 136 | + } |
| 137 | + // Set attribute for current operator. |
| 138 | + for (const proto::OpDesc::Attr& attr : op.attrs()) { |
| 139 | + op_desc.SetAttr(attr.name(), GetAttrValue(attr)); |
| 140 | + } |
| 141 | + // Create a Node for current operator. |
| 142 | + Node* op_node = graph->CreateOpNode(&op_desc); |
| 143 | + for (Node* node : in_nodes) { |
| 144 | + IR_NODE_LINK_TO(node, op_node); |
| 145 | + } |
| 146 | + for (Node* node : out_nodes) { |
| 147 | + IR_NODE_LINK_TO(op_node, node); |
| 148 | + } |
| 149 | + } |
| 150 | + // Remove nodes that are intermediate. |
| 151 | + std::unordered_set<const Node*> remove_nodes; |
| 152 | + for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) { |
| 153 | + remove_nodes.emplace(subgraph.at(pdnode.get())); |
| 154 | + } |
| 155 | + for (auto iter : var_node_maps) { |
| 156 | + remove_nodes.erase(iter.second); |
| 157 | + } |
| 158 | + GraphSafeRemoveNodes(graph, remove_nodes); |
| 159 | + }; |
| 160 | + return handler; |
| 161 | +} |
| 162 | + |
| 163 | +GeneratePass::GeneratePass(const std::string& binary_str) { |
| 164 | + multi_pass_desc_.ParseFromString(binary_str); |
| 165 | + VerifyDesc(); |
| 166 | +} |
| 167 | + |
| 168 | +GeneratePass::GeneratePass(const proto::MultiPassDesc& multi_pass_desc) |
| 169 | + : multi_pass_desc_(multi_pass_desc) { |
| 170 | + VerifyDesc(); |
| 171 | +} |
| 172 | + |
| 173 | +void GeneratePass::ApplyImpl(Graph* graph) const { |
| 174 | + for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { |
| 175 | + GraphPatternDetector detector; |
| 176 | + InitGeneratePattern(pass_desc, detector.mutable_pattern()); |
| 177 | + detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc)); |
| 178 | + // The rewrited graph needs to be verified. Current Pass should be skipped |
| 179 | + // if validation failed. Rewrite based on the original graph cannot |
| 180 | + // implement rollback operation. |
| 181 | + VerifyGraph(*graph); |
| 182 | + } |
| 183 | +} |
| 184 | + |
| 185 | +void GeneratePass::VerifyDesc() const { |
| 186 | + PADDLE_ENFORCE_NE(multi_pass_desc_.pass_descs_size(), 0, |
| 187 | + platform::errors::InvalidArgument( |
| 188 | + "Size of PassDesc should not be empty.")); |
| 189 | + for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) { |
| 190 | + // Check inputs/outputs of subgraph should in `var_maps`. |
| 191 | + std::set<std::string> pattern_var_sets, replace_var_sets; |
| 192 | + for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) { |
| 193 | + pattern_var_sets.emplace(var_map.pattern_var()); |
| 194 | + replace_var_sets.emplace(var_map.replace_var()); |
| 195 | + } |
| 196 | + auto check_vars = [=](std::set<std::string>* var_sets, |
| 197 | + const proto::BlockDesc& block) { |
| 198 | + for (const proto::OpDesc& op : block.ops()) { |
| 199 | + for (const proto::OpDesc::Var& var : op.outputs()) { |
| 200 | + for (const std::string& argument : var.arguments()) { |
| 201 | + var_sets->emplace(argument); |
| 202 | + } |
| 203 | + } |
| 204 | + } |
| 205 | + for (const proto::OpDesc& op : block.ops()) { |
| 206 | + for (const proto::OpDesc::Var& var : op.inputs()) { |
| 207 | + for (const std::string& argument : var.arguments()) { |
| 208 | + PADDLE_ENFORCE_NE( |
| 209 | + var_sets->find(argument), var_sets->end(), |
| 210 | + platform::errors::InvalidArgument( |
| 211 | + "Subgraph of PassDesc has argument [%s] not in `var_maps`.", |
| 212 | + argument)); |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | + }; |
| 217 | + check_vars(&pattern_var_sets, pass_desc.pattern().blocks(0)); |
| 218 | + check_vars(&replace_var_sets, pass_desc.replace().blocks(0)); |
| 219 | + } |
| 220 | +} |
| 221 | + |
| 222 | +bool GeneratePass::VerifyGraph(const Graph& graph) { |
| 223 | + // Return true temporarily. |
| 224 | + return true; |
| 225 | +} |
| 226 | + |
| 227 | +} // namespace ir |
| 228 | +} // namespace framework |
| 229 | +} // namespace paddle |
0 commit comments