mkl gelu #18

Merged 1 commit on Apr 29, 2020

30 changes: 30 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -277,6 +277,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
csinfo_.fused_matmul = "_FusedMatMul";
csinfo_.fused_matmul_grad = "_FusedMatMulGrad";
csinfo_.gelu = "Gelu";
csinfo_.gelu_grad = "GeluGrad";
csinfo_.identity = "Identity";
csinfo_.leakyrelu = "LeakyRelu";
csinfo_.leakyrelu_grad = "LeakyReluGrad";
@@ -500,6 +502,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});

rinfo_.push_back({csinfo_.gelu, mkl_op_registry::GetMklOpName(csinfo_.gelu),
CopyAttrsAll, GeluRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.gelu_grad,
mkl_op_registry::GetMklOpName(csinfo_.gelu_grad),
CopyAttrsAll, GeluRewrite, kRewriteForLayoutPropagation});
rinfo_.push_back({csinfo_.identity,
mkl_op_registry::GetMklOpName(csinfo_.identity),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
@@ -947,6 +954,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string fused_depthwise_conv2d;
string fused_matmul;
string fused_matmul_grad;
string gelu;
string gelu_grad;
string identity;
string leakyrelu;
string leakyrelu_grad;
@@ -1551,6 +1560,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}

// MKL-DNN's Gelu only supports the approximate version,
// so we only rewrite Gelu to the MKL op when approximate is true.
static bool GeluRewrite(const Node* n) {
DCHECK(n);

bool approximate = false;
bool has_attr = TryGetNodeAttr(n->def(), "approximate", &approximate);
DCHECK(has_attr);

// If approximate is true, rewrite the node.
// Otherwise the Eigen node is used instead.
if (approximate) {
return true;
}
VLOG(1) << "GeluRewrite: The model sets approximate is false "
<< "which case is not optimized by Intel MKL, thus using Eigen op"
<< "for Gelu ";

return false;
}

// If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
// path. The unoptimized path is slow. Thus we don't rewrite the node
// and use default Eigen. But for depth_radius=2, MKL DNN optimized
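A note on the GeluRewrite gate above: the rewrite only fires when approximate is true because MKL-DNN's eltwise_gelu implements the tanh approximation rather than the exact erf-based definition. For reference, the standard formulations (not part of this diff) are:

\[ \mathrm{GELU}(x) = x\,\Phi(x) = \tfrac{x}{2}\Big(1 + \operatorname{erf}\big(x/\sqrt{2}\big)\Big) \qquad \text{(exact, approximate = false)} \]

\[ \mathrm{GELU}_{\mathrm{approx}}(x) = \tfrac{x}{2}\Big(1 + \tanh\big(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\big)\Big) \qquad \text{(approximate = true)} \]

The second expression is what MklGeluOp::Compute_Scalar evaluates later in this diff, with \(\sqrt{2/\pi}\) written as M_2_SQRTPI * M_SQRT1_2.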
65 changes: 65 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -2876,6 +2876,71 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Grad_Positive) {
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}

#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Gelu'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'approximate' value { b: true } }" \
" input: ['A']}" \
"node { name: 'C' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT \
");B(_MklGelu);C(Zeta);DMT/_0(Const)|A->B;A->C;" \
"A:control->DMT/_0:control;B->C:1;DMT/_0->B:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Gelu_Positive);
#undef REGISTER_TEST
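For context, REGISTER_TEST_ALL_TYPES and the INPUT placeholder are defined earlier in mkl_layout_pass_test.cc and are not shown in this diff. A sketch of the typical expansion follows; the input-op names are illustrative assumptions, not taken from this PR.

// Sketch only: Float32Input / BFloat16Input are assumed helper ops.
#define REGISTER_TEST_FLOAT32(NAME) REGISTER_TEST(NAME, DT_FLOAT, Float32Input);
#define REGISTER_TEST_BFLOAT16(NAME) \
REGISTER_TEST(NAME, DT_BFLOAT16, BFloat16Input);
#define REGISTER_TEST_ALL_TYPES(NAME) \
REGISTER_TEST_FLOAT32(NAME); \
REGISTER_TEST_BFLOAT16(NAME);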

#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: 'Gelu'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'approximate' value { b: false } }" \
" input: ['A']}" \
"node { name: 'C' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'B'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(Gelu);C(Zeta)|A->B;A->C;B->C:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_Gelu_Negative);
#undef REGISTER_TEST

#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: '" #INPUT \
"'}" \
"node { name: 'C' op: 'GeluGrad'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'approximate' value { b: true } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta'" \
"attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ( \
DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT \
");C(_MklGeluGrad);D(Zeta);DMT/_0(Const);" \
"DMT/_1(Const)|A->C;A->D;A:control->DMT/_0:control;" \
"A:control->DMT/_1:control;B->C:1;C->D:1;DMT/_0->C:2;DMT/_1->C:3"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_GeluGrad_Positive);
#undef REGISTER_TEST
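For symmetry with NodeRewrite_Gelu_Negative above, a negative GeluGrad test is sketched below. It is not part of this PR; it assumes the pass leaves approximate=false GeluGrad nodes untouched, mirroring the expected strings of the existing tests.

#define REGISTER_TEST(NAME, T, INPUT) \
TEST_F(MklLayoutPassTest, NAME##_##T) { \
InitGraph("node { name: 'A' op: '" #INPUT \
"'}" \
"node { name: 'B' op: '" #INPUT \
"'}" \
"node { name: 'C' op: 'GeluGrad'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" attr { key: 'approximate' value { b: false } }" \
" input: ['A', 'B']}" \
"node { name: 'D' op: 'Zeta'" \
" attr { key: 'T' value { type: " #T \
" } }" \
" input: ['A', 'C'] }"); \
EXPECT_EQ(DoMklLayoutOptimizationPass(), \
"A(" #INPUT ");B(" #INPUT ");C(GeluGrad);D(Zeta)|A->C;A->D;B->C:1;C->D:1"); \
}
REGISTER_TEST_ALL_TYPES(NodeRewrite_GeluGrad_Negative);
#undef REGISTER_TEST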

TEST_F(MklLayoutPassTest, NodeRewrite_Relu6Relu6Grad_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
@@ -8165,6 +8165,7 @@ tf_mkl_kernel_library(
prefix = "mkl_relu",
deps = [
":bounds_check",
":no_op",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
119 changes: 118 additions & 1 deletion tensorflow/core/kernels/mkl_relu_op.cc
@@ -19,14 +19,15 @@ limitations under the License.
#include <unordered_map>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

using mkldnn::algorithm;
using mkldnn::eltwise_forward;
@@ -1163,6 +1164,101 @@ class MklLeakyReluGradOp
}
};

template <typename Device, typename T>
class MklGeluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_gelu> {
public:
~MklGeluOp() {}

explicit MklGeluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, ALGORITHM::eltwise_gelu>(context, 0.0f, 0.0f) {
bool approximate;
OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate));
OP_REQUIRES(
context, approximate,
errors::InvalidArgument("MKL Gelu only supports approximate is true. "
"approximate is: ",
approximate));
}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
const size_t dst_index = 0; // index of dst output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);

Tensor* dst_tensor = nullptr;
T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
MklDnnShape dnn_shape_dst;
dnn_shape_dst.SetMklTensor(false);
AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
src_tensor.shape(), dnn_shape_dst);

T* out_o = dst_tensor->flat<T>().data();
T features = user_i[0];
out_o[0] =
static_cast<T>(0.5) * features *
(static_cast<T>(1) +
std::tanh(static_cast<T>(M_2_SQRTPI * M_SQRT1_2) *
(features + static_cast<T>(0.044715) *
std::pow(features, static_cast<T>(3)))));
return;
}
};

template <typename Device, typename T>
class MklGeluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_gelu> {
public:
~MklGeluGradOp() {}

explicit MklGeluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_gelu>(context, 0.0f,
0.0f) {
bool approximate;
OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate));
OP_REQUIRES(
context, approximate,
errors::InvalidArgument("MKL Gelu only supports approximate is true. "
"approximate is: ",
approximate));
}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
const size_t diff_src_index = 0; // index of diff_src output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
Tensor* diff_src_tensor = nullptr;

MklDnnShape dnn_shape_diff_dst;
GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

MklDnnShape dnn_shape_diff_src;
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
diff_dst_tensor.shape(), dnn_shape_diff_src);
T* out_o = diff_src_tensor->flat<T>().data();
T* user_i = const_cast<T*>(src_tensor.flat<T>().data());
T* user_g = const_cast<T*>(diff_dst_tensor.flat<T>().data());

T features = user_i[0];
const T kAlpha = static_cast<T>(M_2_SQRTPI * M_SQRT1_2);
const T kBeta = kAlpha * static_cast<T>(0.044715) * static_cast<T>(3);
const auto y = std::tanh(
(kAlpha *
((static_cast<T>(0.044715) * std::pow(features, static_cast<T>(3))) +
features)));
out_o[0] = user_g[0] * static_cast<T>(0.5) *
((-features * y * y + features) *
(kBeta * features * features + kAlpha) +
static_cast<T>(1) + y);

return;
}
};

// register dnn kernels for supported operations and supported types
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER( \
@@ -1245,6 +1341,27 @@ TF_CALL_bfloat16(REGISTER_RELU6_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_float(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_LeakyRelu_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
REGISTER_KERNEL_BUILDER( \
Name("_MklGelu") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklGeluOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklGeluGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklGeluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_GELU_MKL_SUPPORTED_KERNELS_TYPES);
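// Note (assumption, following the usual MKL op pattern): the NoOp
// registrations below let bfloat16 Gelu/GeluGrad nodes be placed on CPU even
// though there is no Eigen bfloat16 kernel; the MKL layout pass is then
// expected to rewrite such nodes to the _MklGelu/_MklGeluGrad kernels above.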

REGISTER_KERNEL_BUILDER(
Name("Gelu").Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), NoOp);
REGISTER_KERNEL_BUILDER(
Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<bfloat16>("T"), NoOp);

} // namespace tensorflow

#endif // INTEL_MKL
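As a cross-check on MklGeluGradOp::Compute_Scalar above (a worked derivation, not part of the diff): with u(x) = sqrt(2/pi) * (x + 0.044715 x^3) and y = tanh(u(x)), differentiating the approximate GELU gives

\[ \frac{d}{dx}\Big[\tfrac{x}{2}(1 + y)\Big] = \tfrac{1}{2}\Big[(1 + y) + x\,(1 - y^{2})\,u'(x)\Big], \qquad u'(x) = \sqrt{2/\pi}\,\big(1 + 3 \cdot 0.044715\,x^{2}\big). \]

Since kAlpha = sqrt(2/pi) and kBeta = 3 * 0.044715 * kAlpha, the term x (1 - y^2) u'(x) equals (features - features * y * y) * (kAlpha + kBeta * features * features), so the value assigned to out_o[0] is exactly this derivative multiplied by the incoming gradient user_g[0].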
38 changes: 36 additions & 2 deletions tensorflow/core/ops/nn_ops.cc
@@ -1077,15 +1077,15 @@ REGISTER_OP("Dilation2DBackpropFilter")
REGISTER_OP("Gelu")
.Input("features: T")
.Output("activations: T")
.Attr("T: {half, float, double}")
.Attr("T: {half, float, double, bfloat16}")
.Attr("approximate: bool = true")
.SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("GeluGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
.Attr("T: {half, float, double}")
.Attr("T: {half, float, double, bfloat16}")
.Attr("approximate: bool = true")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);

@@ -2209,6 +2209,40 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklGelu")
.Input("features: T")
.Input("mkl_features: uint8")
.Output("activations: T")
.Output("mkl_activations: uint8")
.Attr("T: {float, bfloat16} = DT_FLOAT")
.Attr("approximate: bool = true")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
MKL version of Gelu operator. Uses MKL DNN APIs to implement
Gelu operator.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklGeluGrad")
.Input("gradients: T")
.Input("features: T")
.Input("mkl_gradients: uint8")
.Input("mkl_features: uint8")
.Output("backprops: T")
.Output("mkl_backprops: uint8")
.Attr("T: {float, bfloat16} = DT_FLOAT")
.Attr("approximate: bool = true")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
MKL version of GeluGrad operator. Uses MKL DNN APIs to compute the
gradients for the Gelu operation.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");

REGISTER_OP("_MklElu")
.Input("features: T")
.Input("mkl_features: uint8")