Skip to content

Commit

Permalink
Merge pull request #11 from gongshaotian/op_extension
Browse files Browse the repository at this point in the history
Overloading the operator() method of Op, supporting double tensor inputs
  • Loading branch information
yuanlehome committed Aug 15, 2023
2 parents 98e5d52 + a98230c commit b929f26
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
9 changes: 9 additions & 0 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ Tensor& Op::operator()(const Tensor& arg) const {
return out;
}

Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const {
std::vector<const Tensor*> inputs{&arg1, &arg2};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
"tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
std::vector<const Tensor*> outputs{&out};
pattern_graph_->AddOpCall(std::make_shared<OpCall>(this, inputs, outputs));
return out;
}

Tensor& Op::operator()() const {
std::vector<const Tensor*> inputs{};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class Op {
Tensor& operator()() const;

Tensor& operator()(const Tensor& arg) const;
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const
// Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const
// Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor&
Expand Down
7 changes: 5 additions & 2 deletions test/cpp/ir/pattern_rewrite/drr_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ struct RemoveRedundentReshapeFunctor {
// Source patterns:待匹配的子图
ir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &reshape = pat.Op("reshape");
pat.Tensor("ret") = reshape(reshape(pat.Tensor("arg0")));

pat.Tensor("ret") = reshape(reshape(pat.Tensor("arg0"), pat.Tensor("shape0")), pat.Tensor("shape1"));

// Result patterns:要替换为的子图
ir::drr::ResultPattern res = pat.ResultPattern();
res.Tensor("ret") = res.Op("reshape")(res.Tensor("arg0"));

//
res.Tensor("ret") = res.Op("reshape")(res.Tensor("arg0"), res.Tensor("shape1"));
}
};

Expand Down

0 comments on commit b929f26

Please sign in to comment.