Merge pull request PaddlePaddle#25 from zyfncg/drr_pass
[DRR] Support subgraph replace in source pattern graph for drr
yuanlehome authored Sep 11, 2023
2 parents 7260c1e + 9c9c16b commit 0833a06
Showing 8 changed files with 264 additions and 59 deletions.
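What this change enables: a DRR (Declarative Rewrite Rule) source pattern may now be a multi-op subgraph; matching finds the whole subgraph, and the rewriter splices in the result pattern and erases only the operations that were actually replaced. Below is a sketch of the kind of rule this unlocks, modeled on the re-enabled drr_fuse_linear_test.cc; the DrrPatternBase API details are assumed from that test rather than shown in this diff:

class FusedLinearPattern : public ir::drr::DrrPatternBase<FusedLinearPattern> {
 public:
  void operator()(ir::drr::DrrPatternContext *ctx) const override {
    // Source pattern: a two-op subgraph, matmul feeding add.
    ir::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op("pd.matmul");
    const auto &add = pat.Op("pd.add");
    pat.Tensor("tmp") = matmul(pat.Tensor("x"), pat.Tensor("w"));
    pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));

    // Result pattern: the matched subgraph collapses into one fused op.
    ir::drr::ResultPattern res = pat.ResultPattern();
    const auto &fused = res.Op("pd.fused_gemm_epilogue");
    res.Tensor("out") =
        fused(res.Tensor("x"), res.Tensor("w"), res.Tensor("bias"));
  }
};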
1 change: 1 addition & 0 deletions paddle/fluid/ir/drr/api/match_context.cc
@@ -39,6 +39,7 @@ template bool MatchContext::Attr<bool>(const std::string&) const;
template int32_t MatchContext::Attr<int32_t>(const std::string&) const;
template int64_t MatchContext::Attr<int64_t>(const std::string&) const;
template float MatchContext::Attr<float>(const std::string&) const;
template std::string MatchContext::Attr<std::string>(const std::string&) const;
template std::vector<int32_t> MatchContext::Attr<std::vector<int32_t>>(
const std::string&) const;
template std::vector<int64_t> MatchContext::Attr<std::vector<int64_t>>(
13 changes: 11 additions & 2 deletions paddle/fluid/ir/drr/attr_type_uilts.h
@@ -27,15 +27,16 @@ struct CppTypeToIrAttribute;

#define PD_SPECIALIZE_CppTypeToIrAttribute(cpp_type, ir_attr_type) \
template <> \
struct CppTypeToIrAttribute<cpp_type> { \
struct CppTypeToIrAttribute< \
std::remove_const_t<std::remove_reference_t<cpp_type>>> { \
using type = ir_attr_type; \
};

PD_SPECIALIZE_CppTypeToIrAttribute(bool, BoolAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, Int32Attribute);
PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, Int64Attribute);
PD_SPECIALIZE_CppTypeToIrAttribute(float, FloatAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(const std::string&, StrAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::string, StrAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType,
paddle::dialect::DataTypeAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);
@@ -59,6 +60,14 @@ struct IrAttrTypeCast {
}
};

template <>
struct IrAttrTypeCast<std::string> {
static std::string To(const ir::Attribute& attr) {
return attr.dyn_cast<typename CppTypeToIrAttribute<std::string>::type>()
.AsString();
}
};

template <>
struct IrAttrTypeCast<std::vector<int32_t>> {
static std::vector<int32_t> To(const ir::Attribute& attr) {
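Why the macro now decays its key: CppTypeToIrAttribute is specialized on std::remove_const_t<std::remove_reference_t<cpp_type>>, so std::string becomes the single canonical key, and a const std::string& spelling would resolve to the same specialization. A self-contained illustration of the mechanism (stand-in types, not the Paddle classes):

#include <string>
#include <type_traits>

template <typename T>
struct CppTypeToIrAttribute;

struct StrAttribute {};  // stand-in for the real IR attribute class

#define PD_SPECIALIZE(cpp_type, ir_attr_type)                   \
  template <>                                                   \
  struct CppTypeToIrAttribute<                                  \
      std::remove_const_t<std::remove_reference_t<cpp_type>>> { \
    using type = ir_attr_type;                                  \
  };

// Either std::string or const std::string& here would define the same
// specialization (defining both would be a redefinition error), which is
// why the call sites in ir_operation_creator.cc were normalized to the
// plain std::string spelling.
PD_SPECIALIZE(std::string, StrAttribute)

static_assert(std::is_same_v<CppTypeToIrAttribute<std::string>::type,
                             StrAttribute>,
              "lookup by the decayed key finds the specialization");

int main() { return 0; }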
125 changes: 102 additions & 23 deletions paddle/fluid/ir/drr/drr_rewrite_pattern.h
@@ -123,6 +123,18 @@ class DrrRewritePattern : public ir::RewritePattern {
source_pattern_match_ctx->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_input_value));

// Input tensor is optional (may be none)
if (!ir_input_value) {
if (drr_brother_ops.size() != 1) { // Only used by current op
matched = false;
VLOG(6) << " --- match false: drr_brother_ops is "
<< drr_brother_ops.size()
<< ", but ir_input_value is null ";
}
continue;
}

if (drr_brother_ops.size() != ir_input_value.use_count()) {
matched = false;
VLOG(6) << " --- match false: " << drr_brother_ops.size()
@@ -276,20 +288,25 @@
void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
VLOG(6) << "Create Operations in result_pattern_graph";
MatchContextImpl res_match_ctx = CreateOperations(
*result_pattern_graph_, source_pattern_match_ctx, rewriter);
MatchContextImpl res_match_ctx = CreateOperations(*source_pattern_graph_,
*result_pattern_graph_,
source_pattern_match_ctx,
rewriter);
VLOG(6) << "Process Assign Tensor";
RebindIrTensorForAssignTensor(*result_pattern_graph_, &res_match_ctx);
VLOG(6) << "Replace Output Values in source_pattern_graph by Output Values "
"in result_pattern_graph";
ReplaceOutputTensor(source_pattern_match_ctx, res_match_ctx, rewriter);
VLOG(6) << "Delete Operations in source_pattern_graph";
DeleteSourcePatternOp(
*source_pattern_graph_, source_pattern_match_ctx, rewriter);
DeleteSourcePatternOp(*source_pattern_graph_,
*result_pattern_graph_,
source_pattern_match_ctx,
rewriter);
}

private:
MatchContextImpl CreateOperations(
const SourcePatternGraph& source_pattern_graph,
const ResultPatternGraph& result_pattern_graph,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
@@ -306,12 +323,22 @@
}
}

// set insert point
for (const auto& output : result_pattern_graph.output_tensors()) {
if (source_pattern_graph.id2owend_tensor().count(output)) {
auto ir_value = src_match_ctx.GetIrValue(output);
if (ir_value.get()) {
rewriter.SetInsertionPointAfter(ir_value.get().GetDefiningOp());
break;
}
}
}

// topo order visit result_pattern_graph
GraphTopo graph_topo_visit(&result_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder(
[&src_match_ctx, &rewriter, &res_match_ctx](const OpCall& op_call) {
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
});
graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
});

return res_match_ctx;
}
@@ -333,20 +360,21 @@
void ReplaceOutputTensor(const MatchContextImpl& src_match_ctx,
const MatchContextImpl& res_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& output_name : source_pattern_graph_->output_tensors()) {
if (result_pattern_graph_->output_tensors().count(output_name)) {
for (const auto& output_name : result_pattern_graph_->output_tensors()) {
if (source_pattern_graph_->output_tensors().count(output_name)) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get());
} else {
LOG(WARNING) << "The output tensor (" << output_name
<< ") in the source_pattern_graph is not the output "
"tensor in result_pattern_graph.";
<< ") in the result_pattern_graph is not the tensor"
" in source_pattern_graph.";
}
}
}

void DeleteSourcePatternOp(const SourcePatternGraph& source_pattern_graph,
const ResultPatternGraph& result_pattern_graph,
const MatchContextImpl& src_match_ctx,
ir::PatternRewriter& rewriter) const { // NOLINT
std::vector<const OpCall*> topo_order_ops;
@@ -355,18 +383,69 @@
[&topo_order_ops](const OpCall& op_call) {
topo_order_ops.push_back(&op_call);
});
// Delete Operations in reverse topological order, starting from the output tensors.
std::for_each(topo_order_ops.rbegin(),
topo_order_ops.rend(),
[&src_match_ctx, &rewriter](const OpCall* op_call) {
IR_ENFORCE(src_match_ctx.operation_map().count(op_call),
"Drr OpCall [%s] must exists in match context.",
op_call->name());
auto* op = src_match_ctx.operation_map().at(op_call)->get();
VLOG(6) << "Delete (" << op_call->name() << " @" << op_call
<< " :@" << op << ") in source_pattern_graph ";
rewriter.EraseOp(op);

// Identify the operations that are replaced by the result pattern
// 1. Filter operations by forward walk
std::unordered_set<std::string> forward_visited_tensor_set(
result_pattern_graph.input_tensors());
std::unordered_set<const OpCall*> forward_deleted_ops;
std::for_each(topo_order_ops.begin(),
topo_order_ops.end(),
[&forward_deleted_ops,
&forward_visited_tensor_set](const OpCall* op_call) {
for (const auto* input : op_call->inputs()) {
if (forward_visited_tensor_set.count(input->name())) {
forward_deleted_ops.insert(op_call);
for (const auto* output : op_call->outputs()) {
forward_visited_tensor_set.insert(output->name());
}
break;
}
}
});
// 2. Filter operations by backward walk and merge the forward result
std::unordered_set<std::string> backward_visited_tensor_set(
result_pattern_graph.output_tensors());
std::vector<const OpCall*> deleted_ops;
std::unordered_set<const OpCall*> deleted_ops_set;
std::for_each(
topo_order_ops.rbegin(),
topo_order_ops.rend(),
[&deleted_ops,
&deleted_ops_set,
&backward_visited_tensor_set,
&forward_deleted_ops](const OpCall* op_call) {
bool all_consumers_deleted = true;
for (const auto* output : op_call->outputs()) {
if (backward_visited_tensor_set.count(output->name())) {
for (const auto* consumer : output->consumers()) {
if (!deleted_ops_set.count(consumer)) {
all_consumers_deleted = false;
}
}
} else {
all_consumers_deleted = false;
}
}
if (all_consumers_deleted && forward_deleted_ops.count(op_call)) {
deleted_ops_set.insert(op_call);
deleted_ops.push_back(op_call);
for (const auto* input : op_call->inputs()) {
backward_visited_tensor_set.insert(input->name());
}
}
});

// Erase the filtered operations in reverse topological order (consumers first).
for (const auto* op_call : deleted_ops) {
IR_ENFORCE(src_match_ctx.operation_map().count(op_call),
"Drr OpCall [%s] must exists in match context.",
op_call->name());
auto* op = src_match_ctx.operation_map().at(op_call)->get();
VLOG(6) << "Delete (" << op_call->name() << " @" << op_call << " :@" << op
<< ") in source_pattern_graph ";
rewriter.EraseOp(op);
}
}

const std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
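The forward/backward walk in DeleteSourcePatternOp is what makes subgraph replacement safe: only operations that lie between the result pattern's input and output tensors are erased, so matched ops that still feed consumers outside the replaced region survive. A condensed, self-contained sketch of the same filtering idea (hypothetical Op struct; it omits the per-consumer "already deleted" check that the real backward pass also performs):

#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// topo: ops in topological order; in/out: the result pattern's input and
// output tensor names. Returns the ops to erase, consumers first.
std::vector<const Op*> OpsToDelete(const std::vector<Op>& topo,
                                   const std::unordered_set<std::string>& in,
                                   const std::unordered_set<std::string>& out) {
  // 1. Forward walk: an op becomes a candidate once any of its inputs is
  // reachable from the replaced inputs; its outputs become reachable too.
  std::unordered_set<std::string> fwd(in);
  std::unordered_set<const Op*> fwd_ops;
  for (const auto& op : topo) {
    for (const auto& name : op.inputs) {
      if (fwd.count(name)) {
        fwd_ops.insert(&op);
        fwd.insert(op.outputs.begin(), op.outputs.end());
        break;
      }
    }
  }
  // 2. Backward walk from the replaced outputs: keep a candidate only if
  // every one of its outputs feeds the replaced region.
  std::unordered_set<std::string> bwd(out);
  std::vector<const Op*> deleted;
  for (auto it = topo.rbegin(); it != topo.rend(); ++it) {
    bool all_outputs_covered = true;
    for (const auto& name : it->outputs) {
      if (!bwd.count(name)) all_outputs_covered = false;
    }
    if (all_outputs_covered && fwd_ops.count(&*it)) {
      deleted.push_back(&*it);
      bwd.insert(it->inputs.begin(), it->inputs.end());
    }
  }
  return deleted;  // reverse topological order: safe to erase front to back
}

int main() { return 0; }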
48 changes: 23 additions & 25 deletions paddle/fluid/ir/drr/ir_operation_creator.cc
@@ -56,11 +56,9 @@ static ir::Attribute CreateIrAttribute(const std::any& obj) {
} else if (obj.type() == typeid(float)) {
return IrAttrbuteCreator<float>()(std::any_cast<float>(obj));
} else if (obj.type() == typeid(std::string)) {
return IrAttrbuteCreator<const std::string&>()(
std::any_cast<std::string>(obj));
return IrAttrbuteCreator<std::string>()(std::any_cast<std::string>(obj));
} else if (obj.type() == typeid(const char*)) {
return IrAttrbuteCreator<const std::string&>()(
std::any_cast<const char*>(obj));
return IrAttrbuteCreator<std::string>()(std::any_cast<const char*>(obj));
} else if (obj.type() == typeid(phi::DataType)) {
return IrAttrbuteCreator<phi::DataType>()(
std::any_cast<phi::DataType>(obj));
@@ -192,27 +190,27 @@ Operation* CreateOperation(const OpCall& op_call,
CreateAttributeMap(op_call, src_match_ctx));
res_match_ctx->BindIrValue(op_call.outputs()[0]->name(),
std::make_shared<IrValue>(op->result(0)));
// } else if (op_call.name() == "pd.fused_gemm_epilogue") {
// const auto& inputs = op_call.inputs();
// std::vector<Value> ir_values =
// GetIrValuesByDrrTensors(inputs, *res_match_ctx);
// Operation* op = rewriter.Build<paddle::dialect::FusedGemmEpilogueOp>(
// ir_values[0].dyn_cast<ir::OpResult>(),
// ir_values[1].dyn_cast<ir::OpResult>(),
// ir_values[2].dyn_cast<ir::OpResult>(),
// CreateAttributeMap(op_call, src_match_ctx));
// BindIrOutputs(op_call, op, res_match_ctx);
// } else if (op_call.name() == "pd.fused_gemm_epilogue_grad") {
// const auto& inputs = op_call.inputs();
// std::vector<Value> ir_values =
// GetIrValuesByDrrTensors(inputs, *res_match_ctx);
// op = rewriter.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
// ir_values[0].dyn_cast<ir::OpResult>(),
// ir_values[1].dyn_cast<ir::OpResult>(),
// ir_values[2].dyn_cast<ir::OpResult>(),
// ir_values[3].dyn_cast<ir::OpResult>(),
// CreateAttributeMap(op_call, src_match_ctx));
// BindIrOutputs(op_call, op, res_match_ctx);
} else if (op_call.name() == "pd.fused_gemm_epilogue") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
Operation* op = rewriter.Build<paddle::dialect::FusedGemmEpilogueOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
ir_values[1].dyn_cast<ir::OpResult>(),
ir_values[2].dyn_cast<ir::OpResult>(),
CreateAttributeMap(op_call, src_match_ctx));
BindIrOutputs(op_call, op, res_match_ctx);
} else if (op_call.name() == "pd.fused_gemm_epilogue_grad") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
GetIrValuesByDrrTensors(inputs, *res_match_ctx);
op = rewriter.Build<paddle::dialect::FusedGemmEpilogueGradOp>(
ir_values[0].dyn_cast<ir::OpResult>(),
ir_values[1].dyn_cast<ir::OpResult>(),
ir_values[2].dyn_cast<ir::OpResult>(),
ir_values[3].dyn_cast<ir::OpResult>(),
CreateAttributeMap(op_call, src_match_ctx));
BindIrOutputs(op_call, op, res_match_ctx);
} else if (op_call.name() == "builtin.combine") {
const auto& inputs = op_call.inputs();
std::vector<Value> ir_values =
13 changes: 8 additions & 5 deletions paddle/fluid/ir/drr/ir_value.h
@@ -53,11 +53,14 @@ class IrValue : public TensorInterface {
public:
explicit IrValue(const ir::Value& value)
: value_(value),
shape_(
&value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims()),
dtype_(&value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()) {}
shape_(value ? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()
: nullptr),
dtype_(value ? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()
: nullptr) {}

ShapeInterface Shape() const override { return ShapeInterface(&shape_); }
DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); }
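The guarded initializers above are the companion to the optional-input check in drr_rewrite_pattern.h: a matched-but-absent input arrives as a null ir::Value, and the cached shape/dtype pointers must stay null rather than dereference a null type. A minimal stand-alone rendering of the guard (hypothetical stand-in types, not the Paddle classes):

#include <cassert>

struct Dims {};
struct Type { Dims dims; };

struct Value {  // stand-in for ir::Value
  const Type* type = nullptr;
  explicit operator bool() const { return type != nullptr; }
};

struct IrValueLike {
  explicit IrValueLike(const Value& v)
      : shape_(v ? &v.type->dims : nullptr) {}  // the guard from the patch
  const Dims* shape_;
};

int main() {
  Value none;                 // optional input that was not provided
  IrValueLike wrapped(none);  // safe: no null dereference
  assert(wrapped.shape_ == nullptr);
  return 0;
}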
7 changes: 6 additions & 1 deletion paddle/ir/core/value.cc
@@ -14,6 +14,7 @@

#include "paddle/ir/core/value.h"

#include <glog/logging.h>
#include <cstddef>

#include "paddle/ir/core/enforce.h"
@@ -273,7 +274,11 @@ uint32_t OpResultImpl::GetResultIndex() const {
return ir::dyn_cast<OpInlineResultImpl>(this)->GetResultIndex();
}

OpResultImpl::~OpResultImpl() { assert(use_empty()); }
OpResultImpl::~OpResultImpl() {
if (!use_empty()) {
LOG(ERROR) << owner()->name() << " operation destroyed but still has uses.";
}
}

ir::Operation *OpResultImpl::owner() const {
// For inline result, pointer offset index to obtain the address of op.
2 changes: 1 addition & 1 deletion test/cpp/ir/pattern_rewrite/CMakeLists.txt
@@ -9,7 +9,7 @@ cc_test_old(pattern_rewrite_test SRCS pattern_rewrite_test.cc DEPS
${PATTERN_REWRITE_TEST_DEPS})

cc_test_old(drr_test SRCS drr_test.cc DEPS gtest drr)
# cc_test_old(drr_fuse_linear_test SRCS drr_fuse_linear_test.cc DEPS gtest drr)
cc_test_old(drr_fuse_linear_test SRCS drr_fuse_linear_test.cc DEPS gtest drr)

cc_test_old(drr_attention_fuse_test SRCS drr_attention_fuse_test.cc DEPS
${PATTERN_REWRITE_TEST_DEPS} drr)