paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc (53 changes: 38 additions & 15 deletions)
@@ -210,6 +210,22 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .AddAttr("y_num_col_dims")
       .IsNumEQ(1)
       .End();
+  AddOpCompat(OpCompat("matmul_v2"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("trans_x")
+      .IsBoolEQ(false)
+      .End()
+      .AddAttr("trans_y")
+      .IsBoolEQ(false)
+      .End();
   AddOpCompat(OpCompat("matmul"))
       .AddInput("X")
       .IsTensor()
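Note on the hunk above: the chained AddOpCompat calls declare the constraints a matmul_v2 op must satisfy before this pass will touch it: tensor inputs X and Y, a tensor output Out, and both trans_x and trans_y fixed to false (the fusion assumes untransposed operands). As a rough, self-contained sketch of how such a fluent constraint builder can work (simplified assumptions; this is not Paddle's actual OpCompat implementation):

// Illustrative sketch of a fluent attribute-constraint checker, loosely
// modeled on the AddOpCompat calls above. NOT Paddle's OpCompat; names
// and behavior are simplified assumptions for explanation only.
#include <functional>
#include <map>
#include <string>

class OpCompatSketch {
 public:
  explicit OpCompatSketch(std::string op) : op_(std::move(op)) {}
  // Select the attribute the next constraint applies to.
  OpCompatSketch& AddAttr(const std::string& name) {
    cur_ = name;
    return *this;
  }
  // Require the selected bool attribute to equal v.
  OpCompatSketch& IsBoolEQ(bool v) {
    checks_[cur_] = [v](bool got) { return got == v; };
    return *this;
  }
  // End() is a no-op here; in a real DSL it closes the current clause.
  OpCompatSketch& End() { return *this; }
  // True only if every registered attribute constraint holds.
  bool Check(const std::map<std::string, bool>& attrs) const {
    for (const auto& kv : checks_) {
      auto it = attrs.find(kv.first);
      if (it == attrs.end() || !kv.second(it->second)) return false;
    }
    return true;
  }

 private:
  std::string op_, cur_;
  std::map<std::string, std::function<bool(bool)>> checks_;
};

// Usage: a matmul_v2 with trans_x == true would fail the compat check.
// OpCompatSketch("matmul_v2")
//     .AddAttr("trans_x").IsBoolEQ(false).End()
//     .AddAttr("trans_y").IsBoolEQ(false).End()
//     .Check({{"trans_x", false}, {"trans_y", false}});  // -> true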
@@ -355,7 +371,8 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
              quantized_op_type == "fc" ||
              quantized_op_type == "conv2d_transpose") {
     op_desc->SetAttr("Input_scale", scale_value);
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     op_desc->SetAttr("X_scale", scale_value);
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
@@ -387,7 +404,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
              quantized_op_type == "conv2d_transpose") {
     weight_name = "Filter";
     input_name = "Input";
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     weight_name = "Y";
     input_name = "X";
   } else if (quantized_op_type == "fc") {
@@ -396,7 +414,7 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else {
     PADDLE_THROW(platform::errors::Unimplemented(
         "QuantDequantFuse: We only support conv2d, conv2d_fusion, "
-        "conv2d_transpose, fc, mul, matmul for "
+        "conv2d_transpose, fc, mul, matmul, matmul_v2 for "
         "now."));
   }
   const std::string pattern_name = "dequant_fuse";
@@ -479,14 +497,14 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   // If quantized op is conv2d, weight scale size = weight dims[0]
   // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
   if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-      quantized_op_type == "fc") {
+      quantized_op_type == "matmul_v2" || quantized_op_type == "fc") {
     if (dequant_type == "fake_dequantize_max_abs") {
-      PADDLE_ENFORCE_EQ(
-          weight_scale.size(), 1,
-          platform::errors::InvalidArgument(
-              "mul/matmul op weight dequantized by [fake_dequantize_max_abs] "
-              "requires weight scale size = 1, but got %d.",
-              weight_scale.size()));
+      PADDLE_ENFORCE_EQ(weight_scale.size(), 1,
+                        platform::errors::InvalidArgument(
+                            "mul/matmul/matmul_v2 op weight dequantized by "
+                            "[fake_dequantize_max_abs] "
+                            "requires weight scale size = 1, but got %d.",
+                            weight_scale.size()));
       for (int j = 0; j < weight_tensor->numel(); j++) {
         quantized_weight_data[j] *= weight_scale[0];
       }
@@ -497,17 +515,19 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
          PADDLE_ENFORCE_EQ(
              quant_axis == 1, true,
              platform::errors::InvalidArgument(
-                 "'quant_axis' of mul/matmul/fc op weight dequantized by "
+                 "'quant_axis' of mul/matmul/fc/matmul_v2 op weight "
+                 "dequantized by "
                  "[fake_channel_wise_dequantize_max_abs]should be 1, but "
                  "the received is %d",
                  quant_axis));
        }
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
-               "mul/matmul op weight dequantized by "
+               "mul/matmul/matmul_v2 op weight dequantized by "
                "[fake_channel_wise_dequantize_max_abs] requires weight scale "
-               "size = 2nd dim of mul/matmul's weight, which is %d, but got "
+               "size = 2nd dim of mul/matmul/matmul_v2's weight, which is %d, "
+               "but got "
                "%d.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        for (int j = 0; j < weight_tensor->numel(); j++) {
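To make the channel-wise branch above concrete: for mul/matmul/matmul_v2/fc weights quantized along quant_axis == 1, each of the w_dims[1] output columns carries its own scale, and the loop that follows (truncated in this hunk) rescales every element by its column's scale. A standalone sketch under those assumptions (illustrative names, not Paddle's internal API):

// Sketch of channel-wise dequantization with quant_axis == 1, mirroring
// the constraint checked above: weight_scale.size() == w_dims[1].
// Names here are illustrative, not Paddle's internal API.
#include <cstdio>
#include <vector>

int main() {
  const int w_dims0 = 2, w_dims1 = 3;  // weight is 2 x 3, row-major
  std::vector<float> weight = {10, 20, 30, 40, 50, 60};  // int8 values as float
  std::vector<float> scale = {0.5f, 0.25f, 0.1f};  // one scale per column
  // Element j sits in column j % w_dims1, so each value is rescaled by
  // its column's dequantization scale.
  for (int j = 0; j < w_dims0 * w_dims1; ++j) weight[j] *= scale[j % w_dims1];
  for (float w : weight) std::printf("%g ", w);  // prints: 5 5 3 20 12.5 6
  std::printf("\n");
  return 0;
}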
@@ -594,7 +614,8 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
   } else if (quantized_op_type == "fc") {
     new_op_desc.SetInput("Input", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
-  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul") {
+  } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+             quantized_op_type == "matmul_v2") {
     new_op_desc.SetInput("X", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
   }
@@ -621,7 +642,9 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
   std::unordered_set<std::string> quant_types = {
       "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
   std::unordered_set<std::string> quantized_op_types = {
-      "conv2d", "mul", "matmul", "depthwise_conv2d", "fc", "conv2d_transpose"};
+      "conv2d", "mul", "matmul", "depthwise_conv2d",
+      "conv2d_transpose", "fc", "matmul_v2",
+  };
   auto* scope = param_scope();
 
   for (auto& quant_type : quant_types) {
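For context, ApplyImpl appears to drive the whole pass by considering each fake-quant op type against each supported quantized op type; the hunk above extends the latter set with matmul_v2. A minimal sketch of that cross-product iteration (the printf is a stand-in for the real DeleteQuant/FuseDequant work, and the exact pairing logic is an assumption here, since the loop body is truncated in the diff):

// Hedged sketch of the iteration implied by the sets above. NOT the
// pass's actual control flow; it only shows that every quant type gets
// considered against every supported quantized op type.
#include <cstdio>
#include <string>
#include <unordered_set>

int main() {
  const std::unordered_set<std::string> quant_types = {
      "fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
  const std::unordered_set<std::string> quantized_op_types = {
      "conv2d", "mul", "matmul", "depthwise_conv2d",
      "conv2d_transpose", "fc", "matmul_v2"};
  for (const auto& quant_type : quant_types) {
    for (const auto& op_type : quantized_op_types) {
      std::printf("would try fuse: %s + %s\n", quant_type.c_str(),
                  op_type.c_str());
    }
  }
  return 0;
}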