Add MatMul and Gelu forward fusion for MKLDNN backend. #25

Merged · 1 commit · Apr 29, 2020
2 changes: 2 additions & 0 deletions tensorflow/core/grappler/op_types.cc
@@ -282,6 +282,8 @@ bool IsGather(const NodeDef& node) {
return op == "Gather" || op == "GatherV2";
}

bool IsGelu(const NodeDef& node) { return node.op() == "Gelu"; }

bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }

bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
1 change: 1 addition & 0 deletions tensorflow/core/grappler/op_types.h
@@ -83,6 +83,7 @@ bool IsFusedBatchNorm(const NodeDef& node);
bool IsFusedBatchNormEx(const NodeDef& node);
bool IsFusedBatchNormGrad(const NodeDef& node);
bool IsGather(const NodeDef& node);
bool IsGelu(const NodeDef& node);
bool IsGreater(const NodeDef& node);
bool IsGreaterEqual(const NodeDef& node);
bool IsHistogramSummary(const NodeDef& node);
4 changes: 4 additions & 0 deletions tensorflow/core/grappler/optimizers/remapper.cc
@@ -393,7 +393,11 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
}

bool IsSupportedActivation(const NodeDef& node) {
#ifdef INTEL_MKL
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsGelu(node);
#else
return IsRelu(node) || IsRelu6(node) || IsElu(node);
#endif
}

inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
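
For orientation, here is a minimal sketch (not part of this PR) of the subgraph that the remapper can now fuse on the MKL-DNN path: MatMul feeding BiasAdd feeding Gelu. It assumes the C++ ops builders that the unit test below also uses (ops::MatMul, ops::BiasAdd, ops::Gelu); node names are illustrative.

    // Sketch only: build the MatMul -> BiasAdd -> Gelu pattern that
    // IsSupportedActivation() now admits when INTEL_MKL is defined.
    Scope root = Scope::NewRootScope();
    auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
    auto w = ops::Placeholder(root.WithOpName("w"), DT_FLOAT);
    auto b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
    auto matmul = ops::MatMul(root.WithOpName("matmul"), x, w);
    auto bias_add = ops::BiasAdd(root.WithOpName("bias_add"), matmul, b);
    auto gelu = ops::Gelu(root.WithOpName("gelu"), bias_add);
    // After remapping, these three nodes are expected to collapse into a single
    // fused MatMul whose fused_ops attribute lists {"BiasAdd", "Gelu"}.
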
18 changes: 17 additions & 1 deletion tensorflow/core/kernels/mkl_fused_ops_test.cc
@@ -878,6 +878,12 @@ class MklFusedMatMulOpTest : public OpsTestBase {
next_op = ops::Elu(root.WithOpName(last_op), next_op);
}

if (std::find(fused_ops.begin(), fused_ops.end(), "Gelu") !=
fused_ops.end()) {
last_op = "with_gelu";
next_op = ops::Gelu(root.WithOpName(last_op), next_op);
}

CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
};

@@ -965,11 +971,21 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndElu) {
{"BiasAdd", "Elu"});
}

TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndGelu) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 5;

this->VerifyFusedMatMul(batch, input_channel, output_channel,
{"BiasAdd", "Gelu"});
}

REGISTER_TYPED_TEST_CASE_P(MklFusedMatMulOpTest, //
WithBias, //
WithBiasAndRelu, //
WithBiasAndRelu6, //
WithBiasAndElu);
WithBiasAndElu, //
WithBiasAndGelu);

using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedMatMulOpTest,
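
The new WithBiasAndGelu case follows the existing pattern of this suite: VerifyFusedMatMul builds an unfused reference graph (MatMul + BiasAdd + Gelu, as constructed by the helper above) and compares its output against the fused kernel. For reference only, the erf form of GELU is sketched below; this is the standard definition, not code from this PR, and MKL-DNN may use the tanh approximation internally.

    // Standard erf-form GELU, for reference (not part of this change):
    //   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    #include <cmath>
    float GeluRef(float x) {
      return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
    }
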
2 changes: 2 additions & 0 deletions tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -221,6 +221,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
} else if (post_op == "Elu") {
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
} else if (post_op == "Gelu") {
params.post_op_params.push_back({"gelu", {1.0, 1.0, 0.0}});
} else {
OP_REQUIRES_OK(
ctx, errors::InvalidArgument(
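
The three floats pushed for "Gelu" follow the same {scale, alpha, beta} convention as the existing entries and are unpacked as param[0]/param[1]/param[2] in mkl_matmul_ops_common.h below; alpha and beta carry no meaning for GELU and simply mirror the "elu" entry. A commented restatement of that line:

    // post_op_params convention (see mkl_matmul_ops_common.h):
    //   param[0] = scale, param[1] = alpha, param[2] = beta
    params.post_op_params.push_back({"gelu", {/*scale=*/1.0, /*alpha=*/1.0, /*beta=*/0.0}});
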
11 changes: 10 additions & 1 deletion tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -259,6 +259,13 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_elu, op_alpha,
op_beta);
} else if (post_op_param.name == "gelu") {
DCHECK_EQ(post_op_param.param.size(), 3);
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_gelu, op_alpha,
op_beta);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
std::vector<float> scales;
@@ -268,6 +275,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "relu6") ||
(post_op_param.name == "elu") ||
(post_op_param.name == "gelu") ||
(post_op_param.name == "output_scale"));
}
}
@@ -372,11 +380,12 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);

// Generate keys for post-ops
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
post_op_param.name == "elu") {
post_op_param.name == "elu" || post_op_param.name == "gelu") {
DCHECK_EQ(post_op_param.param.size(), 3);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
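
For context, here is a standalone sketch of the eltwise post-op attachment that the "gelu" branch performs via append_eltwise; it assumes the DNNL/MKL-DNN 1.x C++ API and is not code from this PR:

    #include "mkldnn.hpp"

    // Attach a GELU eltwise post-op to the primitive attributes, mirroring
    // post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_gelu, op_alpha, op_beta).
    mkldnn::primitive_attr attr;
    mkldnn::post_ops po;
    po.append_eltwise(/*scale=*/1.0f, mkldnn::algorithm::eltwise_gelu,
                      /*alpha=*/1.0f, /*beta=*/0.0f);
    attr.set_post_ops(po);
    // attr is then supplied when creating the matmul/inner-product primitive descriptor.
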