diff --git a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc old mode 100755 new mode 100644 index 2e0ea00b76163..d03215c052da9 --- a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc +++ b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc @@ -83,6 +83,11 @@ void Op::operator()(const Tensor& arg, const Tensor* out) const { pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); } +void Op::operator()(const std::vector& args, + const std::vector& outputs) const { + pattern_graph_->AddOpCall(std::make_shared(this, args, outputs)); +} + Tensor& Op::operator()(const Tensor& arg) const { std::vector inputs{&arg}; auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( @@ -95,7 +100,7 @@ Tensor& Op::operator()(const Tensor& arg) const { Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const { std::vector inputs{&arg1, &arg2}; auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( - "tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + "tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); std::vector outputs{&out}; pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); return out; @@ -115,8 +120,7 @@ int64_t Op::count = 0; void Tensor::operator=(Tensor& other) const { // NOLINT // The two tensor must be in the same pattern graph. IR_ENFORCE(this->pattern_graph_ == other.pattern_graph_); - if (other.name_.substr(0, 4) == "tmp_" && - name_.substr(0, 4) != "tmp_") { + if (other.name_.substr(0, 4) == "tmp_" && name_.substr(0, 4) != "tmp_") { other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_); } } diff --git a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h old mode 100755 new mode 100644 index 7c9ed09a5fb6b..7ac2e8d10e7a7 --- a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h +++ b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h @@ -36,12 +36,12 @@ class ResultPatternGraph; class Attribute { public: - explicit Attribute(const std::string& id) : attr_id_(id) {} + explicit Attribute(const std::string& name) : attr_name_(name) {} - const std::string& id() const { return attr_id_; } + const std::string& name() const { return attr_name_; } private: - std::string attr_id_; + std::string attr_name_; }; class TensorShape { @@ -118,16 +118,17 @@ class Op { Tensor& operator()(const Tensor& arg) const; Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; + void operator()(const std::vector& args, + const std::vector& outputs) const; // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const // Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const // Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& // operator()(const Tensor& arg0, const Tensor& arg1, const Tensor& arg2, // const Tensor& arg3, const Tensor& arg4) const; - // void operator()(const std::vector& args, const - // std::vector& outputs) const; private: friend class DrrPatternContext; + friend class OpCall; Op(const std::string& op_type_name, const std::unordered_map& attributes, @@ -136,6 +137,10 @@ class Op { attributes_(attributes), pattern_graph_(pattern_graph) {} + const std::unordered_map& attributes() const { + return attributes_; + } + static int64_t count; std::string op_type_name_; @@ -187,18 +192,26 @@ class OpCall { OpCall(const Op* op, const std::vector& inputs, const std::vector& outputs) - : op_(op), inputs_(inputs), outputs_(outputs) {} + : op_name_(op->op_type_name_), + inputs_(inputs), + outputs_(outputs), + attributes_(op->attributes_) {} - const std::string& name() const { return op_->name(); } + const std::string& name() const { return op_name_; } const std::vector& inputs() const { return inputs_; } const std::vector& outputs() const { return outputs_; } + const std::unordered_map& attributes() const { + return attributes_; + } + private: - const Op* op_; + std::string op_name_; std::vector inputs_; std::vector outputs_; + std::unordered_map attributes_; }; class ResultPattern { diff --git a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h index 8b77af1cad2fd..8be586f08e8ad 100644 --- a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h +++ b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h @@ -48,11 +48,23 @@ class DrrRewritePattern : public ir::OpRewritePattern { source_pattern_graph_->Print(); result_pattern_graph_->Print(); + } - source_pattern_match_ctx_ = std::make_unique(); + bool MatchAndRewrite(SourceOp op, + PatternRewriter& rewriter) const override { // NOLINT + std::shared_ptr src_match_ctx = + std::make_shared(); + if (PatternGraphMatch(op, src_match_ctx)) { + PatternGraphRewrite(op, *src_match_ctx, rewriter); + return true; + } + return false; } - bool Match(SourceOp op) const override { + private: + bool PatternGraphMatch( + SourceOp op, + const std::shared_ptr& source_pattern_match_ctx) const { // Match auto* anchor = source_pattern_graph_->AnchorNode(); IR_ENFORCE(anchor); @@ -64,7 +76,7 @@ class DrrRewritePattern : public ir::OpRewritePattern { ir_q.push(op); drr_visited.insert(anchor); ir_visited.insert(op); - source_pattern_match_ctx_->BindIrOperation( + source_pattern_match_ctx->BindIrOperation( anchor, std::make_shared(op)); bool Matched = true; size_t step = 0; @@ -84,7 +96,6 @@ class DrrRewritePattern : public ir::OpRewritePattern { break; } - // // op's inputs const auto& drr_input_tensors = drr_node->inputs(); auto ir_input_value_size = ir_node->num_operands(); @@ -98,13 +109,15 @@ class DrrRewritePattern : public ir::OpRewritePattern { // check brother ops auto drr_brother_ops = drr_input_tensors[i]->consumers(); auto ir_input_value = ir_node->operand(i).source(); - source_pattern_match_ctx_->BindIrValue( + + source_pattern_match_ctx->BindIrValue( drr_input_tensors[i]->name(), std::make_shared(ir_input_value)); if (drr_brother_ops.size() != ir_input_value.use_count()) { Matched = false; break; } + for (auto* drr_brother_op : drr_brother_ops) { if (drr_visited.count(drr_brother_op) == 0) { std::pair found{false, nullptr}; @@ -125,7 +138,7 @@ class DrrRewritePattern : public ir::OpRewritePattern { ir_q.push(found.second); drr_visited.insert(drr_brother_op); ir_visited.insert(found.second); - source_pattern_match_ctx_->BindIrOperation( + source_pattern_match_ctx->BindIrOperation( drr_brother_op, std::make_shared(found.second)); } else { Matched = false; @@ -134,6 +147,11 @@ class DrrRewritePattern : public ir::OpRewritePattern { } } + if (source_pattern_graph_->input_tensors().count( + drr_input_tensors[i]->name())) { + continue; + } + // check ancestor op auto drr_ancestor_op = drr_input_tensors[i]->producer(); auto ir_ancestor_op = ir_input_value.GetDefiningOp(); @@ -145,12 +163,11 @@ class DrrRewritePattern : public ir::OpRewritePattern { ir_q.push(ir_ancestor_op); drr_visited.insert(drr_ancestor_op); ir_visited.insert(ir_ancestor_op); - source_pattern_match_ctx_->BindIrOperation( + source_pattern_match_ctx->BindIrOperation( drr_ancestor_op, std::make_shared(ir_ancestor_op)); } } - // // op's outputs const auto& drr_output_tensors = drr_node->outputs(); auto ir_output_value_size = ir_node->num_results(); @@ -159,18 +176,24 @@ class DrrRewritePattern : public ir::OpRewritePattern { Matched = false; break; } + for (size_t i = 0; i < drr_output_tensors.size(); ++i) { if (!Matched) break; // check child ops auto drr_child_ops = drr_output_tensors[i]->consumers(); auto ir_output_value = ir_node->result(i); - source_pattern_match_ctx_->BindIrValue( + source_pattern_match_ctx->BindIrValue( drr_output_tensors[i]->name(), std::make_shared(ir_output_value)); + if (source_pattern_graph_->output_tensors().count( + drr_output_tensors[i]->name())) { + continue; + } if (drr_child_ops.size() != ir_output_value.use_count()) { Matched = false; break; } + for (auto* drr_child_op : drr_child_ops) { if (drr_visited.count(drr_child_op) == 0) { std::pair found{false, nullptr}; @@ -191,7 +214,7 @@ class DrrRewritePattern : public ir::OpRewritePattern { ir_q.push(found.second); drr_visited.insert(drr_child_op); ir_visited.insert(found.second); - source_pattern_match_ctx_->BindIrOperation( + source_pattern_match_ctx->BindIrOperation( drr_child_op, std::make_shared(found.second)); } else { Matched = false; @@ -212,7 +235,7 @@ class DrrRewritePattern : public ir::OpRewritePattern { // Matched = Matched && step == source_pattern_graph_->CountOfOpCalls(); // Constraints - MatchContext match_context{source_pattern_match_ctx_}; + MatchContext match_context{source_pattern_match_ctx}; for (const auto& constraint : constraints_) { Matched = constraint(match_context); if (!Matched) break; @@ -221,18 +244,20 @@ class DrrRewritePattern : public ir::OpRewritePattern { return Matched; } - void Rewrite(SourceOp op, - ir::PatternRewriter& rewriter) const override { // NOLINT + void PatternGraphRewrite(SourceOp op, + const MatchContextImpl& source_pattern_match_ctx, + ir::PatternRewriter& rewriter) const { // NOLINT // 1. Create Operations in result_pattern_graph MatchContextImpl res_match_ctx = CreateOperations( - *result_pattern_graph_, *source_pattern_match_ctx_, rewriter); + *result_pattern_graph_, source_pattern_match_ctx, rewriter); // 2. Replace Output Values in source_pattern_graph by Output Values in // result_pattern_graph - ReplaceOutputTensor(*source_pattern_match_ctx_, res_match_ctx, rewriter); + ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter); // 3. Delete Operations in source_pattern_graph - DeleteSourcePatternOp(*source_pattern_match_ctx_, rewriter); + DeleteSourcePatternOp( + *source_pattern_graph_, source_pattern_match_ctx, rewriter); } private: @@ -263,25 +288,42 @@ class DrrRewritePattern : public ir::OpRewritePattern { const MatchContextImpl& res_match_ctx, ir::PatternRewriter& rewriter) const { // NOLINT for (const auto& output_name : source_pattern_graph_->output_tensors()) { - const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); - const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); - rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + if (result_pattern_graph_->output_tensors().count(output_name)) { + const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); + const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); + rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + } else { + LOG(WARNING) << "The output tensor (" << output_name + << ") in the source_pattern_graph is not the output " + "tensor in result_pattern_graph."; + } } } - void DeleteSourcePatternOp(const MatchContextImpl& src_match_ctx, + void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph, + const MatchContextImpl& src_match_ctx, ir::PatternRewriter& rewriter) const { // NOLINT - for (const auto& kv : src_match_ctx.operation_map()) { - rewriter.EraseOp(kv.second->get()); - } + std::vector topo_order_ops; + GraphTopo graph_topo_visit(&source_pattern_graph); + graph_topo_visit.WalkGraphNodesTopoOrder( + [&topo_order_ops](const OpCall& op_call) { + topo_order_ops.push_back(&op_call); + }); + // Delete Operation with topo order from output tensors. + std::for_each( + topo_order_ops.rbegin(), + topo_order_ops.rend(), + [&src_match_ctx, &rewriter](const OpCall* op_call) { + auto* op = src_match_ctx.operation_map().at(op_call)->get(); + VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" + << op << ") in source_pattern_graph "; + rewriter.EraseOp(src_match_ctx.operation_map().at(op_call)->get()); + }); } - private: std::shared_ptr source_pattern_graph_; std::vector constraints_; std::shared_ptr result_pattern_graph_; - - std::shared_ptr source_pattern_match_ctx_; }; } // namespace drr diff --git a/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h index 23d7e341e568f..ce2079408dcf8 100644 --- a/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h +++ b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h @@ -52,9 +52,12 @@ Operation* CreateOperation(const OpCall& op_call, Operation* reshape_op = rewriter.Build( ir_values[0].dyn_cast(), std::vector{16, 3, 4, 16}); - auto out = reshape_op->result(0); - res_match_ctx->BindIrValue(op_call.outputs()[0]->name(), - std::make_shared(out)); + res_match_ctx->BindIrValue( + op_call.outputs()[0]->name(), + std::make_shared(reshape_op->result(0))); + res_match_ctx->BindIrValue( + op_call.outputs()[1]->name(), + std::make_shared(reshape_op->result(1))); return reshape_op; } LOG(ERROR) << "Unknown op " << op_call.name(); diff --git a/paddle/ir/pattern_rewrite/drr/match_context_impl.h b/paddle/ir/pattern_rewrite/drr/match_context_impl.h index 09e6c599eeb4b..c4fd33eb1ec30 100644 --- a/paddle/ir/pattern_rewrite/drr/match_context_impl.h +++ b/paddle/ir/pattern_rewrite/drr/match_context_impl.h @@ -18,6 +18,7 @@ #include #include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h" #include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h" #include "paddle/ir/pattern_rewrite/drr/ir_operation.h" #include "paddle/ir/pattern_rewrite/drr/ir_value.h" @@ -25,7 +26,6 @@ namespace ir { namespace drr { -class OpCall; template struct CppTypeToIrAttribute; @@ -64,11 +64,19 @@ class MatchContextImpl final { return *tensor_map_.at(tensor_name); } + ir::Attribute GetIrAttr(const std::string& tensor_name) const { + return attr_map_.at(tensor_name); + } + const std::unordered_map>& operation_map() const { return operation_map_; } + const std::unordered_map& attr_map() const { + return attr_map_; + } + void BindIrValue(const std::string& value_name, const std::shared_ptr& value) { tensor_map_.emplace(value_name, value); @@ -77,6 +85,10 @@ class MatchContextImpl final { void BindIrOperation(const OpCall* op_call, const std::shared_ptr& op) { operation_map_.emplace(op_call, op); + const auto& attrs = op_call->attributes(); + for (const auto& kv : attrs) { + BindIrAttr(kv.second.name(), op->get()->attribute(kv.first)); + } } void BindIrAttr(const std::string& attr_name, ir::Attribute attr) { diff --git a/paddle/ir/pattern_rewrite/drr/pattern_graph.cc b/paddle/ir/pattern_rewrite/drr/pattern_graph.cc old mode 100755 new mode 100644 index a9caf5486489f..23508469cb76c --- a/paddle/ir/pattern_rewrite/drr/pattern_graph.cc +++ b/paddle/ir/pattern_rewrite/drr/pattern_graph.cc @@ -25,7 +25,7 @@ namespace drr { const drr::OpCall &PatternGraph::AddOpCall( const std::shared_ptr &op_call) { owned_op_call_.push_back(op_call); - for (const auto &input : op_call->inputs()) { + for (const auto *input : op_call->inputs()) { const auto &tensor_id = input->name(); IR_ENFORCE(id2owned_tensor_.count(tensor_id)); id2owned_tensor_.at(tensor_id)->AddConsumer(op_call.get()); @@ -102,11 +102,10 @@ void PatternGraph::Print() const { } std::cout << "\n" << std::endl; - std::cout << "OpCalls:" << std::endl; for (const auto &op_call : owned_op_call_) { std::cout << " " << op_call->name() << " : "; std::cout << "inputs[ "; - for (const auto &input : op_call->inputs()) { + for (const auto *input : op_call->inputs()) { std::cout << input->name() << " "; } std::cout << "], "; @@ -120,29 +119,31 @@ void PatternGraph::Print() const { std::cout << std::endl; } - const OpCall *SourcePatternGraph::AnchorNode() const { return id2owned_tensor_.at(*output_tensors_.begin())->producer(); } - -void GraphTopo::WalkGraphNodesTopoOrder(const std::function &VisitNode) const { +void GraphTopo::WalkGraphNodesTopoOrder( + const std::function &VisitNode) const { // graph data - const std::unordered_set &inputs_tensor = graph_->input_tensors(); - const std::unordered_map> &id2owned_tensor = graph_->id2owend_tensor(); - const std::vector> &owend_opcall = graph_->owned_op_call(); + const std::unordered_set &inputs_tensor = + graph_->input_tensors(); + const std::unordered_map> &id2owned_tensor = + graph_->id2owend_tensor(); + const std::vector> &owend_opcall = + graph_->owned_op_call(); std::queue opcall_queue; - std::unordered_map> opcall_dependent; - - // init opcall_dependent; - for (const std::shared_ptr &opcall_sptr : owend_opcall){ + std::unordered_map> + opcall_dependent; - if (opcall_sptr.get()->inputs().empty()){ // opcall inputs is empty + // init opcall_dependent; + for (const std::shared_ptr &opcall_sptr : owend_opcall) { + if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty opcall_queue.push(opcall_sptr.get()); - } - else{ - for(const auto &pre_depd_tensor : opcall_sptr.get()->inputs()){ + } else { + for (const auto &pre_depd_tensor : opcall_sptr.get()->inputs()) { opcall_dependent[opcall_sptr.get()].insert(pre_depd_tensor->name()); } } @@ -150,12 +151,13 @@ void GraphTopo::WalkGraphNodesTopoOrder(const std::functionname(); - - for(const auto &tensor_comsumer : id2owned_tensor.at(tensor_id).get()->consumers()){ + const std::string &tensor_name = + id2owned_tensor.at(tensor_id).get()->name(); + for (const auto &tensor_comsumer : + id2owned_tensor.at(tensor_id).get()->consumers()) { opcall_dependent[tensor_comsumer].erase(tensor_name); - if (opcall_dependent[tensor_comsumer].empty()){ + if (opcall_dependent[tensor_comsumer].empty()) { opcall_queue.push(tensor_comsumer); } } @@ -168,19 +170,14 @@ void GraphTopo::WalkGraphNodesTopoOrder(const std::functionoutputs()) { - - for (const auto &tensor_comsumer : output_tensor->consumers()){ - + for (const auto &tensor_comsumer : output_tensor->consumers()) { opcall_dependent[tensor_comsumer].erase(output_tensor->name()); - if (opcall_dependent[tensor_comsumer].empty()){ + if (opcall_dependent[tensor_comsumer].empty()) { opcall_queue.push(tensor_comsumer); } } - } } - - return; } } // namespace drr diff --git a/paddle/ir/pattern_rewrite/drr/pattern_graph.h b/paddle/ir/pattern_rewrite/drr/pattern_graph.h index 0c2e395db2178..92f63fa7507d3 100644 --- a/paddle/ir/pattern_rewrite/drr/pattern_graph.h +++ b/paddle/ir/pattern_rewrite/drr/pattern_graph.h @@ -20,7 +20,6 @@ #include #include #include -#include namespace ir { namespace drr { @@ -55,9 +54,14 @@ class PatternGraph { void Print() const; - const std::vector>& owned_op_call()const { return owned_op_call_; }; + const std::vector>& owned_op_call() const { + return owned_op_call_; + } - const std::unordered_map>& id2owend_tensor() const { return id2owned_tensor_; }; + const std::unordered_map>& id2owend_tensor() + const { + return id2owned_tensor_; + } protected: std::unordered_map> id2owned_tensor_; diff --git a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc index 01d660159f015..d43c3fdd174d6 100644 --- a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,11 +131,16 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter { for (uint32_t i = 0; i < op->num_operands(); ++i) { AddOperandToWorklist(op->operand_source(i)); } - for (uint32_t i = 0; i < op->num_regions(); ++i) { - auto& region = op->region(i); - for (auto& block : region) { - for (auto& op_item : *block) { - RemoveFromWorklist(op_item); + + if (op->num_regions() == 0) { + RemoveFromWorklist(op); + } else { + for (uint32_t i = 0; i < op->num_regions(); ++i) { + auto& region = op->region(i); + for (auto& block : region) { + for (auto& op_item : *block) { + RemoveFromWorklist(op_item); + } } } } diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index bc12b1fbce6e9..a47e58fda0590 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/ir/dialect/pd_dialect.h" +#include "paddle/ir/builtin_transforms/dead_code_elimination_pass.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" #include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h" @@ -27,15 +28,16 @@ struct RemoveRedundentReshapeFunctor { void operator()(ir::drr::DrrPatternContext *ctx) { // Source patterns:待匹配的子图 ir::drr::SourcePattern pat = ctx->SourcePattern(); - const auto &reshape = pat.Op("reshape"); - - pat.Tensor("ret") = reshape(reshape(pat.Tensor("arg0"), pat.Tensor("shape0")), pat.Tensor("shape1")); + const auto &reshape = pat.Op("pd.reshape"); + reshape({&pat.Tensor("arg0"), &pat.Tensor("shape0")}, + {&pat.Tensor("out1"), &pat.Tensor("xshape_0")}); + reshape({&pat.Tensor("out1"), &pat.Tensor("shape1")}, + {&pat.Tensor("ret"), &pat.Tensor("xshape_1")}); // Result patterns:要替换为的子图 ir::drr::ResultPattern res = pat.ResultPattern(); - - // - res.Tensor("ret") = res.Op("reshape")(res.Tensor("arg0"), res.Tensor("shape1")); + res.Op("pd.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, + {&res.Tensor("ret"), &res.Tensor("xshape_1")}); } }; @@ -147,6 +149,7 @@ TEST(DrrTest, drr) { ir::PassManager pm(ctx); pm.AddPass(std::make_unique()); + pm.AddPass(ir::CreateDeadCodeEliminationPass()); pm.EnablePassTiming(); pm.EnableIRPrinting();