
Commit adab3b2

public gelu
1 parent a5b7556 commit adab3b2

23 files changed (+522, -2 lines)
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "Gelu"
+}
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+op {
+  graph_op_name: "GeluGrad"
+}

tensorflow/core/kernels/BUILD

Lines changed: 7 additions & 0 deletions
@@ -4718,6 +4718,7 @@ cc_library(
         ":depthwise_conv_op",
         ":dilation_ops",
         ":fused_batch_norm_op",
+        ":gelu_op",
         ":in_topk_op",
         ":l2loss_op",
         ":lrn_op",
@@ -4803,6 +4804,12 @@ tf_kernel_library(
     deps = NN_DEPS + if_rocm([":conv_ops_gpu_hdrs"]),
 )
 
+tf_kernel_library(
+    name = "gelu_op",
+    prefix = "gelu_op",
+    deps = NN_DEPS,
+)
+
 tf_kernel_library(
     name = "relu_op",
     prefix = "relu_op",
tensorflow/core/kernels/gelu_op.cc

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/nn_ops.cc.
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/gelu_op.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class GeluOp : public UnaryElementWiseOp<T, GeluOp<Device, T>> {
+ public:
+  explicit GeluOp(OpKernelConstruction* context)
+      : UnaryElementWiseOp<T, GeluOp<Device, T>>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_));
+  }
+
+  void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+    functor::Gelu<Device, T> functor;
+    functor(context->eigen_device<Device>(), input.flat<T>(), approximate_,
+            output->flat<T>());
+  }
+
+ private:
+  bool approximate_;
+};
+
+template <typename Device, typename T>
+class GeluGradOp : public BinaryElementWiseOp<T, GeluGradOp<Device, T>> {
+ public:
+  explicit GeluGradOp(OpKernelConstruction* context)
+      : BinaryElementWiseOp<T, GeluGradOp<Device, T>>(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("approximate", &approximate_));
+  }
+
+  void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+                         const Tensor& a, Tensor* output);
+  // INPUTS:
+  //   g (gradients): backpropagated gradients.
+  //   a (inputs): inputs that were passed to GeluOp().
+  // OUTPUT:
+  //   gradients to backprop.
+  template <int NDIMS>
+  void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+               Tensor* output) {
+    OperateNoTemplate(context, g, a, output);
+  }
+
+ private:
+  bool approximate_;
+};
+
+template <typename Device, typename T>
+void GeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+                                              const Tensor& g, const Tensor& a,
+                                              Tensor* output) {
+  OP_REQUIRES(context, a.IsSameSize(g),
+              errors::InvalidArgument("g and a must be the same size"));
+  functor::GeluGrad<Device, T> functor;
+  functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(),
+          approximate_, output->flat<T>());
+}
+
+#define REGISTER_KERNELS(type)                                       \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Gelu").Device(DEVICE_CPU).TypeConstraint<type>("T"),     \
+      GeluOp<CPUDevice, type>);                                      \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("GeluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+      GeluGradOp<CPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T)                                           \
+  template <>                                                         \
+  void Gelu<GPUDevice, T>::operator()(                                \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor features,   \
+      bool approximate, typename TTypes<T>::Tensor activations);      \
+  extern template struct Gelu<GPUDevice, T>;                          \
+                                                                      \
+  template <>                                                         \
+  void GeluGrad<GPUDevice, T>::operator()(                            \
+      const GPUDevice& d, typename TTypes<T>::ConstTensor gradients,  \
+      typename TTypes<T>::ConstTensor features, bool approximate,     \
+      typename TTypes<T>::Tensor backprops);                          \
+  extern template struct GeluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+}  // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER_GPU_KERNELS(type)                                   \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("Gelu").Device(DEVICE_GPU).TypeConstraint<type>("T"),     \
+      GeluOp<GPUDevice, type>);                                      \
+  REGISTER_KERNEL_BUILDER(                                           \
+      Name("GeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+      GeluGradOp<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+
+}  // namespace tensorflow

tensorflow/core/kernels/gelu_op.h

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_GELU_OP_H_
+#define TENSORFLOW_CORE_KERNELS_GELU_OP_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace internal {
+constexpr double kCoeff = 0.044715;
+constexpr double kSqrtHalf = 0.7071067811865476;
+constexpr double kTwoRsqrtPi = 1.1283791670955126;
+constexpr double kAlpha = kSqrtHalf * kTwoRsqrtPi;
+}  // namespace internal
+
+namespace functor {
+
+// Functor used by GeluOp to do the computations.
+template <typename Device, typename T>
+struct Gelu {
+  // Computes Gelu activation.
+  //
+  // features: any shape.
+  // approximate: whether to enable approximation.
+  // activations: same shape as "features".
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+                  bool approximate, typename TTypes<T>::Tensor activations) {
+    const T one = static_cast<T>(1);
+    const T half = static_cast<T>(0.5);
+    if (approximate) {
+      // y = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+      activations.device(d) =
+          half * features *
+          (one +
+           (static_cast<T>(internal::kAlpha) *
+            (features + static_cast<T>(internal::kCoeff) * features.cube()))
+               .tanh());
+    } else {
+      // y = x * normcdf(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+      activations.device(d) =
+          half * features *
+          (one + (features * static_cast<T>(internal::kSqrtHalf)).erf());
+    }
+  }
+};
+
+// Functor used by GeluGradOp to do the computations.
+template <typename Device, typename T>
+struct GeluGrad {
+  // Computes GeluGrad backprops.
+  //
+  // gradients: gradients backpropagated to the Gelu op.
+  // features: inputs that were passed to the Gelu op.
+  // approximate: whether to enable approximation.
+  // backprops: gradients to backpropagate to the Gelu inputs.
+  void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+                  typename TTypes<T>::ConstTensor features, bool approximate,
+                  typename TTypes<T>::Tensor backprops) {
+    const T one = static_cast<T>(1);
+    const T half = static_cast<T>(0.5);
+    if (approximate) {
+      const T kBeta = static_cast<T>(internal::kAlpha) *
+                      static_cast<T>(internal::kCoeff) * static_cast<T>(3);
+      const auto y =
+          (static_cast<T>(internal::kAlpha) *
+           ((static_cast<T>(internal::kCoeff) * features.cube()) + features))
+              .tanh();
+      backprops.device(d) =
+          ((-features * y.square() + features) *
+               (kBeta * features.square() + static_cast<T>(internal::kAlpha)) +
+           one + y) *
+          gradients * half;
+    } else {
+      backprops.device(d) =
+          gradients *
+          (static_cast<T>(internal::kAlpha * 0.5) * features *
+               (-features.square() * half).exp() +
+           (half * (one + (features * static_cast<T>(internal::kSqrtHalf)).erf())));
+    }
+  }
+};
+
+}  // namespace functor
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_GELU_OP_H_
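
The Eigen expression inside the GeluGrad functor is dense, so here, as a reading aid rather than part of the commit, is the math it appears to evaluate, writing alpha for kAlpha = sqrt(2/pi) and c for kCoeff = 0.044715:

\mathrm{gelu}(x) = x\,\Phi(x) = \tfrac{1}{2}\,x\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr),
\qquad
\mathrm{gelu}'(x) = \Phi(x) + x\,\varphi(x)
  = \tfrac{1}{2}\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr) + \tfrac{x}{\sqrt{2\pi}}\,e^{-x^{2}/2}

\mathrm{gelu}_{\mathrm{approx}}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u\bigr),
\quad u = \alpha\,(x + c\,x^{3}),
\qquad
\mathrm{gelu}_{\mathrm{approx}}'(x)
  = \tfrac{1}{2}\bigl(1 + \tanh u\bigr)
  + \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u\bigr)\,\alpha\,\bigl(1 + 3c\,x^{2}\bigr)

The functor then sets backprops = gradients * gelu'(features); kBeta in the code is alpha * 3c, and kAlpha * 0.5 equals 1/sqrt(2*pi), which is why the non-approximate branch is written the way it is.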
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/gelu_op.h"
+
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+// Definition of the GPU implementations declared in gelu_op.cc.
+#define DEFINE_GPU_KERNELS(T)                    \
+  template struct functor::Gelu<GPUDevice, T>;   \
+  template struct functor::GeluGrad<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

tensorflow/core/ops/nn_ops.cc

Lines changed: 15 additions & 0 deletions
@@ -1070,6 +1070,21 @@ REGISTER_OP("Dilation2DBackpropFilter")
 
 // --------------------------------------------------------------------------
 
+REGISTER_OP("Gelu")
+    .Input("features: T")
+    .Output("activations: T")
+    .Attr("T: {half, float, double}")
+    .Attr("approximate: bool = true")
+    .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("GeluGrad")
+    .Input("gradients: T")
+    .Input("features: T")
+    .Output("backprops: T")
+    .Attr("T: {half, float, double}")
+    .Attr("approximate: bool = true")
+    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
 REGISTER_OP("Relu")
     .Input("features: T")
     .Output("activations: T")

tensorflow/python/eager/pywrap_gradient_exclusions.cc

Lines changed: 2 additions & 1 deletion
@@ -410,7 +410,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
 
 absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
     const tensorflow::string &op_name) {
-  static std::array<OpIndexInfo, 459> a = {{
+  static std::array<OpIndexInfo, 460> a = {{
       {"Abs"},
       {"AccumulateNV2"},
       {"Acos"},
@@ -539,6 +539,7 @@ absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
       {"Gather"},
       {"GatherNd"},
       {"GatherV2"},
+      {"Gelu"},
       {"GenerateBoundingBoxProposals"},
       {"GenerateVocabRemapping"},
       {"GetSessionHandle"},

tensorflow/python/keras/activations.py

Lines changed: 20 additions & 0 deletions
@@ -81,6 +81,26 @@ def softmax(x, axis=-1):
                      'Received input: %s' % (x,))
 
 
+@keras_export('keras.activations.gelu')
+def gelu(x, approximate=True):
+  """Gaussian Error Linear Unit.
+
+  Arguments:
+      x: Input tensor.
+
+  Returns:
+      The gaussian error linear activation:
+      `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
+      if `approximate` is `True` or
+      `x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`, where P(X) ~ N(0, 1),
+      if `approximate` is `False`.
+
+  Reference:
+    - [Gaussian Error Linear Units (GELUs)](https://arxiv.org/abs/1606.08415)
+  """
+  return nn.gelu(x, approximate)
+
+
 @keras_export('keras.activations.elu')
 def elu(x, alpha=1.0):
   """Exponential linear unit.
