From a98230c89c0d19ede52c3306827d0f99d7a593b3 Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Tue, 15 Aug 2023 02:59:10 +0000 Subject: [PATCH] Overloading the operator() method of Op, supporting dual tensor inputs --- paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc | 9 +++++++++ paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h | 2 +- test/cpp/ir/pattern_rewrite/drr_test.cc | 7 +++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc index a8c99ce360350..2e0ea00b76163 100755 --- a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc +++ b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc @@ -92,6 +92,15 @@ Tensor& Op::operator()(const Tensor& arg) const { return out; } +Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const { + std::vector inputs{&arg1, &arg2}; + auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( + "tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_))); + std::vector outputs{&out}; + pattern_graph_->AddOpCall(std::make_shared(this, inputs, outputs)); + return out; +} + Tensor& Op::operator()() const { std::vector inputs{}; auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr(new Tensor( diff --git a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h index efe7f1866483e..7c9ed09a5fb6b 100755 --- a/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h +++ b/paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h @@ -117,7 +117,7 @@ class Op { Tensor& operator()() const; Tensor& operator()(const Tensor& arg) const; - // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; + Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const // Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const // Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor& diff --git a/test/cpp/ir/pattern_rewrite/drr_test.cc b/test/cpp/ir/pattern_rewrite/drr_test.cc index 669c3ef27f3b6..bc12b1fbce6e9 100644 --- a/test/cpp/ir/pattern_rewrite/drr_test.cc +++ b/test/cpp/ir/pattern_rewrite/drr_test.cc @@ -28,11 +28,14 @@ struct RemoveRedundentReshapeFunctor { // Source patterns:待匹配的子图 ir::drr::SourcePattern pat = ctx->SourcePattern(); const auto &reshape = pat.Op("reshape"); - pat.Tensor("ret") = reshape(reshape(pat.Tensor("arg0"))); + + pat.Tensor("ret") = reshape(reshape(pat.Tensor("arg0"), pat.Tensor("shape0")), pat.Tensor("shape1")); // Result patterns:要替换为的子图 ir::drr::ResultPattern res = pat.ResultPattern(); - res.Tensor("ret") = res.Op("reshape")(res.Tensor("arg0")); + + // + res.Tensor("ret") = res.Op("reshape")(res.Tensor("arg0"), res.Tensor("shape1")); } };