Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#13 from zyfncg/drr_ir
Browse files Browse the repository at this point in the history
[DRR] Fix bug of DRR test
  • Loading branch information
yuanlehome authored Aug 17, 2023
2 parents b929f26 + 385b7d1 commit 25d4a0a
Show file tree
Hide file tree
Showing 9 changed files with 166 additions and 83 deletions.
10 changes: 7 additions & 3 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ void Op::operator()(const Tensor& arg, const Tensor* out) const {
pattern_graph_->AddOpCall(std::make_shared<OpCall>(this, inputs, outputs));
}

void Op::operator()(const std::vector<const Tensor*>& args,
const std::vector<const Tensor*>& outputs) const {
pattern_graph_->AddOpCall(std::make_shared<OpCall>(this, args, outputs));
}

Tensor& Op::operator()(const Tensor& arg) const {
std::vector<const Tensor*> inputs{&arg};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
Expand All @@ -95,7 +100,7 @@ Tensor& Op::operator()(const Tensor& arg) const {
Tensor& Op::operator()(const Tensor& arg1, const Tensor& arg2) const {
std::vector<const Tensor*> inputs{&arg1, &arg2};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
"tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
"tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
std::vector<const Tensor*> outputs{&out};
pattern_graph_->AddOpCall(std::make_shared<OpCall>(this, inputs, outputs));
return out;
Expand All @@ -115,8 +120,7 @@ int64_t Op::count = 0;
void Tensor::operator=(Tensor& other) const { // NOLINT
// The two tensor must be in the same pattern graph.
IR_ENFORCE(this->pattern_graph_ == other.pattern_graph_);
if (other.name_.substr(0, 4) == "tmp_" &&
name_.substr(0, 4) != "tmp_") {
if (other.name_.substr(0, 4) == "tmp_" && name_.substr(0, 4) != "tmp_") {
other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_);
}
}
Expand Down
29 changes: 21 additions & 8 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ class ResultPatternGraph;

class Attribute {
public:
explicit Attribute(const std::string& id) : attr_id_(id) {}
explicit Attribute(const std::string& name) : attr_name_(name) {}

const std::string& id() const { return attr_id_; }
const std::string& name() const { return attr_name_; }

private:
std::string attr_id_;
std::string attr_name_;
};

class TensorShape {
Expand Down Expand Up @@ -118,16 +118,17 @@ class Op {

Tensor& operator()(const Tensor& arg) const;
Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const;
void operator()(const std::vector<const Tensor*>& args,
const std::vector<const Tensor*>& outputs) const;
// const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const
// Tensor& arg2) const; const Tensor& operator()(const Tensor& arg0, const
// Tensor& arg1, const Tensor& arg2, const Tensor& arg3) const; const Tensor&
// operator()(const Tensor& arg0, const Tensor& arg1, const Tensor& arg2,
// const Tensor& arg3, const Tensor& arg4) const;
// void operator()(const std::vector<Tensor>& args, const
// std::vector<Tensor*>& outputs) const;

private:
friend class DrrPatternContext;
friend class OpCall;

Op(const std::string& op_type_name,
const std::unordered_map<std::string, Attribute>& attributes,
Expand All @@ -136,6 +137,10 @@ class Op {
attributes_(attributes),
pattern_graph_(pattern_graph) {}

const std::unordered_map<std::string, Attribute>& attributes() const {
return attributes_;
}

static int64_t count;

std::string op_type_name_;
Expand Down Expand Up @@ -187,18 +192,26 @@ class OpCall {
OpCall(const Op* op,
const std::vector<const Tensor*>& inputs,
const std::vector<const Tensor*>& outputs)
: op_(op), inputs_(inputs), outputs_(outputs) {}
: op_name_(op->op_type_name_),
inputs_(inputs),
outputs_(outputs),
attributes_(op->attributes_) {}

const std::string& name() const { return op_->name(); }
const std::string& name() const { return op_name_; }

const std::vector<const Tensor*>& inputs() const { return inputs_; }

const std::vector<const Tensor*>& outputs() const { return outputs_; }

const std::unordered_map<std::string, Attribute>& attributes() const {
return attributes_;
}

private:
const Op* op_;
std::string op_name_;
std::vector<const Tensor*> inputs_;
std::vector<const Tensor*> outputs_;
std::unordered_map<std::string, Attribute> attributes_;
};

class ResultPattern {
Expand Down
94 changes: 68 additions & 26 deletions paddle/ir/pattern_rewrite/drr/drr_rewrite_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,23 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {

source_pattern_graph_->Print();
result_pattern_graph_->Print();
}

source_pattern_match_ctx_ = std::make_unique<MatchContextImpl>();
bool MatchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override { // NOLINT
std::shared_ptr<MatchContextImpl> src_match_ctx =
std::make_shared<MatchContextImpl>();
if (PatternGraphMatch(op, src_match_ctx)) {
PatternGraphRewrite(op, *src_match_ctx, rewriter);
return true;
}
return false;
}

bool Match(SourceOp op) const override {
private:
bool PatternGraphMatch(
SourceOp op,
const std::shared_ptr<MatchContextImpl>& source_pattern_match_ctx) const {
// Match
auto* anchor = source_pattern_graph_->AnchorNode();
IR_ENFORCE(anchor);
Expand All @@ -64,7 +76,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(op);
drr_visited.insert(anchor);
ir_visited.insert(op);
source_pattern_match_ctx_->BindIrOperation(
source_pattern_match_ctx->BindIrOperation(
anchor, std::make_shared<IrOperation>(op));
bool Matched = true;
size_t step = 0;
Expand All @@ -84,7 +96,6 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
break;
}

//
// op's inputs
const auto& drr_input_tensors = drr_node->inputs();
auto ir_input_value_size = ir_node->num_operands();
Expand All @@ -98,13 +109,15 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// check brother ops
auto drr_brother_ops = drr_input_tensors[i]->consumers();
auto ir_input_value = ir_node->operand(i).source();
source_pattern_match_ctx_->BindIrValue(

source_pattern_match_ctx->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_input_value));
if (drr_brother_ops.size() != ir_input_value.use_count()) {
Matched = false;
break;
}

for (auto* drr_brother_op : drr_brother_ops) {
if (drr_visited.count(drr_brother_op) == 0) {
std::pair<bool, ir::Operation*> found{false, nullptr};
Expand All @@ -125,7 +138,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(found.second);
drr_visited.insert(drr_brother_op);
ir_visited.insert(found.second);
source_pattern_match_ctx_->BindIrOperation(
source_pattern_match_ctx->BindIrOperation(
drr_brother_op, std::make_shared<IrOperation>(found.second));
} else {
Matched = false;
Expand All @@ -134,6 +147,11 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
}
}

if (source_pattern_graph_->input_tensors().count(
drr_input_tensors[i]->name())) {
continue;
}

// check ancestor op
auto drr_ancestor_op = drr_input_tensors[i]->producer();
auto ir_ancestor_op = ir_input_value.GetDefiningOp();
Expand All @@ -145,12 +163,11 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(ir_ancestor_op);
drr_visited.insert(drr_ancestor_op);
ir_visited.insert(ir_ancestor_op);
source_pattern_match_ctx_->BindIrOperation(
source_pattern_match_ctx->BindIrOperation(
drr_ancestor_op, std::make_shared<IrOperation>(ir_ancestor_op));
}
}

//
// op's outputs
const auto& drr_output_tensors = drr_node->outputs();
auto ir_output_value_size = ir_node->num_results();
Expand All @@ -159,18 +176,24 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
Matched = false;
break;
}

for (size_t i = 0; i < drr_output_tensors.size(); ++i) {
if (!Matched) break;
// check child ops
auto drr_child_ops = drr_output_tensors[i]->consumers();
auto ir_output_value = ir_node->result(i);
source_pattern_match_ctx_->BindIrValue(
source_pattern_match_ctx->BindIrValue(
drr_output_tensors[i]->name(),
std::make_shared<IrValue>(ir_output_value));
if (source_pattern_graph_->output_tensors().count(
drr_output_tensors[i]->name())) {
continue;
}
if (drr_child_ops.size() != ir_output_value.use_count()) {
Matched = false;
break;
}

for (auto* drr_child_op : drr_child_ops) {
if (drr_visited.count(drr_child_op) == 0) {
std::pair<bool, ir::Operation*> found{false, nullptr};
Expand All @@ -191,7 +214,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
ir_q.push(found.second);
drr_visited.insert(drr_child_op);
ir_visited.insert(found.second);
source_pattern_match_ctx_->BindIrOperation(
source_pattern_match_ctx->BindIrOperation(
drr_child_op, std::make_shared<IrOperation>(found.second));
} else {
Matched = false;
Expand All @@ -212,7 +235,7 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
// Matched = Matched && step == source_pattern_graph_->CountOfOpCalls();

// Constraints
MatchContext match_context{source_pattern_match_ctx_};
MatchContext match_context{source_pattern_match_ctx};
for (const auto& constraint : constraints_) {
Matched = constraint(match_context);
if (!Matched) break;
Expand All @@ -221,18 +244,20 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
return Matched;
}

void Rewrite(SourceOp op,
ir::PatternRewriter& rewriter) const override { // NOLINT
void PatternGraphRewrite(SourceOp op,
const MatchContextImpl& source_pattern_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
// 1. Create Operations in result_pattern_graph
MatchContextImpl res_match_ctx = CreateOperations(
*result_pattern_graph_, *source_pattern_match_ctx_, rewriter);
*result_pattern_graph_, source_pattern_match_ctx, rewriter);

// 2. Replace Output Values in source_pattern_graph by Output Values in
// result_pattern_graph
ReplaceOutputTensor(*source_pattern_match_ctx_, res_match_ctx, rewriter);
ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter);

// 3. Delete Operations in source_pattern_graph
DeleteSourcePatternOp(*source_pattern_match_ctx_, rewriter);
DeleteSourcePatternOp(
*source_pattern_graph_, source_pattern_match_ctx, rewriter);
}

private:
Expand Down Expand Up @@ -263,25 +288,42 @@ class DrrRewritePattern : public ir::OpRewritePattern<SourceOp> {
const MatchContextImpl& res_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& output_name : source_pattern_graph_->output_tensors()) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get());
if (result_pattern_graph_->output_tensors().count(output_name)) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get());
} else {
LOG(WARNING) << "The output tensor (" << output_name
<< ") in the source_pattern_graph is not the output "
"tensor in result_pattern_graph.";
}
}
}

void DeleteSourcePatternOp(const MatchContextImpl& src_match_ctx,
void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& kv : src_match_ctx.operation_map()) {
rewriter.EraseOp(kv.second->get());
}
std::vector<const OpCall*> topo_order_ops;
GraphTopo graph_topo_visit(&source_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder(
[&topo_order_ops](const OpCall& op_call) {
topo_order_ops.push_back(&op_call);
});
// Delete Operation with topo order from output tensors.
std::for_each(
topo_order_ops.rbegin(),
topo_order_ops.rend(),
[&src_match_ctx, &rewriter](const OpCall* op_call) {
auto* op = src_match_ctx.operation_map().at(op_call)->get();
VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@"
<< op << ") in source_pattern_graph ";
rewriter.EraseOp(src_match_ctx.operation_map().at(op_call)->get());
});
}

private:
std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
std::vector<Constraint> constraints_;
std::shared_ptr<ResultPatternGraph> result_pattern_graph_;

std::shared_ptr<MatchContextImpl> source_pattern_match_ctx_;
};

} // namespace drr
Expand Down
9 changes: 6 additions & 3 deletions paddle/ir/pattern_rewrite/drr/ir_operation_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ Operation* CreateOperation(const OpCall& op_call,
Operation* reshape_op = rewriter.Build<paddle::dialect::ReshapeOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
std::vector<int64_t>{16, 3, 4, 16});
auto out = reshape_op->result(0);
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(out));
res_match_ctx->BindIrValue(
op_call.outputs()[0]->name(),
std::make_shared<IrValue>(reshape_op->result(0)));
res_match_ctx->BindIrValue(
op_call.outputs()[1]->name(),
std::make_shared<IrValue>(reshape_op->result(1)));
return reshape_op;
}
LOG(ERROR) << "Unknown op " << op_call.name();
Expand Down
14 changes: 13 additions & 1 deletion paddle/ir/pattern_rewrite/drr/match_context_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
#include <unordered_map>

#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"
#include "paddle/ir/pattern_rewrite/drr/api/tensor_interface.h"
#include "paddle/ir/pattern_rewrite/drr/ir_operation.h"
#include "paddle/ir/pattern_rewrite/drr/ir_value.h"

namespace ir {
namespace drr {

class OpCall;
template <class T>
struct CppTypeToIrAttribute;

Expand Down Expand Up @@ -64,11 +64,19 @@ class MatchContextImpl final {
return *tensor_map_.at(tensor_name);
}

ir::Attribute GetIrAttr(const std::string& tensor_name) const {
return attr_map_.at(tensor_name);
}

const std::unordered_map<const OpCall*, std::shared_ptr<IrOperation>>&
operation_map() const {
return operation_map_;
}

const std::unordered_map<std::string, ir::Attribute>& attr_map() const {
return attr_map_;
}

void BindIrValue(const std::string& value_name,
const std::shared_ptr<IrValue>& value) {
tensor_map_.emplace(value_name, value);
Expand All @@ -77,6 +85,10 @@ class MatchContextImpl final {
void BindIrOperation(const OpCall* op_call,
const std::shared_ptr<IrOperation>& op) {
operation_map_.emplace(op_call, op);
const auto& attrs = op_call->attributes();
for (const auto& kv : attrs) {
BindIrAttr(kv.second.name(), op->get()->attribute(kv.first));
}
}

void BindIrAttr(const std::string& attr_name, ir::Attribute attr) {
Expand Down
Loading

0 comments on commit 25d4a0a

Please sign in to comment.