diff --git a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
index c1e6f7b6e0ac8..af22465da30a1 100644
--- a/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
+++ b/paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
@@ -179,7 +179,6 @@ class DrrRewritePattern : public ir::OpRewritePattern {
     for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
       if (!Matched) break;
 
-      // check child ops
       auto drr_child_ops = drr_output_tensors[i]->consumers();
       auto ir_output_value = ir_node->result(i);
 
@@ -266,9 +265,8 @@ class DrrRewritePattern : public ir::OpRewritePattern {
       const MatchContextImpl& src_match_ctx,
       ir::PatternRewriter& rewriter) const {  // NOLINT
     MatchContextImpl res_match_ctx;
-    // add input tensors info for res_match_ctx;
-    const auto& input_tensors = result_pattern_graph.input_tensors();
-    for (const auto& in_tensor : input_tensors) {
+    // add input tensors info for res_match_ctx
+    for (const auto& in_tensor : result_pattern_graph.input_tensors()) {
       res_match_ctx.BindIrValue(
           in_tensor,
           std::make_shared<IrValue>(src_match_ctx.GetIrValue(in_tensor)));
@@ -277,8 +275,8 @@ class DrrRewritePattern : public ir::OpRewritePattern {
     // topo order visit result_pattern_graph
     GraphTopo graph_topo_visit(&result_pattern_graph);
     graph_topo_visit.WalkGraphNodesTopoOrder(
-        [&rewriter, &res_match_ctx](const OpCall& op_call) {
-          CreateOperation(op_call, rewriter, &res_match_ctx);
+        [&src_match_ctx, &rewriter, &res_match_ctx](const OpCall& op_call) {
+          CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
         });
 
     return res_match_ctx;
diff --git a/paddle/ir/pattern_rewrite/drr/ir_operation_creator.cc b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.cc
new file mode 100644
index 0000000000000..b736ffbaf9ba8
--- /dev/null
+++ b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/ir/pattern_rewrite/drr/ir_operation_creator.h"
+#include "paddle/fluid/ir/dialect/pd_op.h"
+
+namespace ir {
+namespace drr {
+
+Value GetIrValueByDrrTensor(const Tensor& tensor,
+                            const MatchContextImpl& res_match_ctx) {
+  return res_match_ctx.GetIrValue(tensor.name()).get();
+}
+
+std::vector<Value> GetIrValuesByDrrTensors(
+    const std::vector<const Tensor*>& tensors,
+    const MatchContextImpl& res_match_ctx) {
+  std::vector<Value> ir_values;
+  ir_values.reserve(tensors.size());
+  for (const auto* tensor : tensors) {
+    ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx));
+  }
+  return ir_values;
+}
+
+ir::AttributeMap CreateAttributeMap(const OpCall& op_call,
+                                    const MatchContextImpl& src_match_ctx) {
+  ir::AttributeMap attr_map;
+  for (const auto& kv : op_call.attributes()) {
+    attr_map[kv.first] = src_match_ctx.GetIrAttr(kv.second.name());
+  }
+  return attr_map;
+}
+
+template <typename T>
+T GetAttr(const std::string& attr_name,
+          const OpCall& op_call,
+          const MatchContextImpl& src_match_ctx) {
+  return src_match_ctx.Attr<T>(op_call.attributes().at(attr_name).name());
+}
+
+Operation* CreateOperation(const OpCall& op_call,
+                           const MatchContextImpl& src_match_ctx,
+                           ir::PatternRewriter& rewriter,  // NOLINT
+                           MatchContextImpl* res_match_ctx) {
+  if (op_call.name() == "pd.reshape") {
+    const auto& inputs = op_call.inputs();
+    std::vector<Value> ir_values =
+        GetIrValuesByDrrTensors(inputs, *res_match_ctx);
+    // TODO(zyfncg): support attr in build op.
+    Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
+        ir_values[0].dyn_cast<ir::OpResult>(),
+        ir_values[1].dyn_cast<ir::OpResult>());
+    res_match_ctx->BindIrValue(
+        op_call.outputs()[0]->name(),
+        std::make_shared<IrValue>(reshape_op->result(0)));
+    res_match_ctx->BindIrValue(
+        op_call.outputs()[1]->name(),
+        std::make_shared<IrValue>(reshape_op->result(1)));
+    return reshape_op;
+  } else if (op_call.name() == "pd.transpose") {
+    const auto& inputs = op_call.inputs();
+    std::vector<Value> ir_values =
+        GetIrValuesByDrrTensors(inputs, *res_match_ctx);
+    Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
+        ir_values[0].dyn_cast<ir::OpResult>(),
+        GetAttr<std::vector<int>>("perm", op_call, src_match_ctx));
+    res_match_ctx->BindIrValue(
+        op_call.outputs()[0]->name(),
+        std::make_shared<IrValue>(transpose_op->result(0)));
+    return transpose_op;
+  }
+
+  LOG(ERROR) << "Unknown op " << op_call.name();
+  return nullptr;
+}
+
+}  // namespace drr
+}  // namespace ir
diff --git a/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
index 5f78701e8ee3f..c655bb4d2c126 100644
--- a/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
+++ b/paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
@@ -20,61 +20,13 @@
 #include "paddle/ir/pattern_rewrite/drr/match_context_impl.h"
 #include "paddle/ir/pattern_rewrite/pattern_match.h"
 
-#include "paddle/fluid/ir/dialect/pd_op.h"
-
 namespace ir {
 namespace drr {
 
-Value GetIrValueByDrrTensor(const Tensor& tensor,
-                            const MatchContextImpl& res_match_ctx) {
-  return res_match_ctx.GetIrValue(tensor.name()).get();
-}
-
-std::vector<Value> GetIrValuesByDrrTensors(
-    const std::vector<const Tensor*>& tensors,
-    const MatchContextImpl& res_match_ctx) {
-  std::vector<Value> ir_values;
-  ir_values.reserve(tensors.size());
-  for (const auto* tensor : tensors) {
-    ir_values.push_back(GetIrValueByDrrTensor(*tensor, res_match_ctx));
-  }
-  return ir_values;
-}
-
 Operation* CreateOperation(const OpCall& op_call,
+                           const MatchContextImpl& src_match_ctx,
                            ir::PatternRewriter& rewriter,  // NOLINT
-                           MatchContextImpl* res_match_ctx) {
-  if (op_call.name() == "pd.reshape") {
-    const auto& inputs = op_call.inputs();
-    std::vector<Value> ir_values =
-        GetIrValuesByDrrTensors(inputs, *res_match_ctx);
-    // TODO(zyfncg): support attr in build op.
-    Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
-        ir_values[0].dyn_cast<ir::OpResult>(),
-        std::vector<int64_t>{16, 3, 4, 16});
-    res_match_ctx->BindIrValue(
-        op_call.outputs()[0]->name(),
-        std::make_shared<IrValue>(reshape_op->result(0)));
-    res_match_ctx->BindIrValue(
-        op_call.outputs()[1]->name(),
-        std::make_shared<IrValue>(reshape_op->result(1)));
-    return reshape_op;
-  }
-  else if(op_call.name() == "pd.transpose") {
-    const auto& inputs = op_call.inputs();
-    std::vector<Value> ir_values = GetIrValuesByDrrTensors(inputs, *res_match_ctx);
-    Operation* transpose_op = rewriter.Build<paddle::dialect::TransposeOp>(
-        ir_values[0].dyn_cast<ir::OpResult>(),
-        std::vector<int>{0, 2, 1, 3});
-    res_match_ctx->BindIrValue(
-        op_call.outputs()[0]->name(),
-        std::make_shared<IrValue>(transpose_op->result(0)));
-    return transpose_op;
-  }
-
-  LOG(ERROR) << "Unknown op " << op_call.name();
-  return nullptr;
-}
+                           MatchContextImpl* res_match_ctx);
 
 // template
 // class CreateOperation {
diff --git a/paddle/ir/pattern_rewrite/drr/match_context_impl.h b/paddle/ir/pattern_rewrite/drr/match_context_impl.h
index c4fd33eb1ec30..e7ec679a65607 100644
--- a/paddle/ir/pattern_rewrite/drr/match_context_impl.h
+++ b/paddle/ir/pattern_rewrite/drr/match_context_impl.h
@@ -40,6 +40,27 @@ PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute);
 PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute);
 PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute);
 
+template <typename T>
+struct IrAttrTypeCast {
+  static T To(const ir::Attribute& attr) {
+    return attr.dyn_cast<typename CppTypeToIrAttribute<T>::type>().data();
+  }
+};
+
+template <>
+struct IrAttrTypeCast<std::vector<int32_t>> {
+  static std::vector<int32_t> To(const ir::Attribute& attr) {
+    std::vector<int32_t> result;
+    for (size_t i = 0; i < attr.dyn_cast<ir::ArrayAttribute>().size(); i++) {
+      result.push_back(attr.dyn_cast<ir::ArrayAttribute>()
+                           .at(i)
+                           .dyn_cast<ir::Int32Attribute>()
+                           .data());
+    }
+    return result;
+  }
+};
+
 class MatchContextImpl final {
  public:
   MatchContextImpl() = default;
@@ -54,18 +75,16 @@
   }
 
   template <typename T>
-  T Attr(const std::string& attr_name) const {
-    return attr_map_.at(attr_name)
-        .dyn_cast<typename CppTypeToIrAttribute<T>::type>()
-        .data();
+  T Attr(const std::string& attr_id) const {
+    return IrAttrTypeCast<T>::To(attr_map_.at(attr_id));
   }
 
   const IrValue& GetIrValue(const std::string& tensor_name) const {
     return *tensor_map_.at(tensor_name);
   }
 
-  ir::Attribute GetIrAttr(const std::string& tensor_name) const {
-    return attr_map_.at(tensor_name);
+  ir::Attribute GetIrAttr(const std::string& attr_id) const {
+    return attr_map_.at(attr_id);
   }
 
   const std::unordered_map<std::string, std::shared_ptr<IrValue>>&
diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc
index dc7326f591651..8e023c733a624 100644
--- a/test/cpp/ir/pattern_rewrite/drr_test.cc
+++ b/test/cpp/ir/pattern_rewrite/drr_test.cc
@@ -17,6 +17,7 @@
 #include <memory>
 
 #include "paddle/fluid/ir/dialect/pd_dialect.h"
+#include "paddle/fluid/ir/dialect/pd_op.h"
 #include "paddle/ir/pass/pass.h"
 #include "paddle/ir/pass/pass_manager.h"
 #include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"
@@ -141,7 +142,7 @@ void BuildProgram(ir::Builder &builder) {  // NOLINT
 
   paddle::dialect::TransposeOp transpose_op2 =
       builder.Build<paddle::dialect::TransposeOp>(transpose_op1.out(),
-                                                  std::vector<int>{0, 1, 2, 3});
+                                                  std::vector<int>{1, 0, 2, 3});
 
   paddle::dialect::ReluOp relu_op_second =
       builder.Build<paddle::dialect::ReluOp>(transpose_op2.out());