Skip to content

Commit

Permalink
gelu using normcdf for cudnn (PaddlePaddle#38450)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuoxiaWang authored Dec 27, 2021
1 parent 5d90295 commit 3702248
Showing 1 changed file with 5 additions and 11 deletions.
16 changes: 5 additions & 11 deletions paddle/fluid/operators/gelu_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ struct GeluWithoutApproximateFunctor {
// Evaluates the exact (non-approximate) GELU of arg_x. The input is widened
// to MPType for the arithmetic and the result is cast back to T.
//
// NOTE(review): this text is a rendered diff without +/- markers, so the
// erf-based statements and the normcdf-based return appear side by side. As
// written, the first return is the one that executes and the normcdf return
// is unreachable. Presumably the commit removes the erf-based lines and keeps
// only the normcdf return -- confirm against the repository.
inline HOSTDEVICE T operator()(T arg_x) {
// actual gelu with approximation = false
MPType x = static_cast<MPType>(arg_x);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
// erf(x / sqrt(2)); M_SQRT1_2 is the <cmath> constant 1/sqrt(2).
MPType erf_out = erf(x * static_cast<MPType>(M_SQRT1_2));
// 0.5 * x * (1 + erf(x / sqrt(2)))
MPType out = x * half * (one + erf_out);
return static_cast<T>(out);
// x * normcdf(x). Assuming normcdf is the standard normal CDF (as in the CUDA
// math API), this is mathematically the same quantity as the erf form above,
// since normcdf(x) = 0.5 * (1 + erf(x / sqrt(2))) -- TODO confirm overload.
return static_cast<T>(x * normcdf(x));
}
};

Expand Down Expand Up @@ -100,12 +96,10 @@ struct GeluWithoutApproximateGradFunctor {
// Computes the derivative of exact (non-approximate) GELU at arg_x and scales
// it by arg_dout. Both inputs are widened to MPType for the arithmetic and the
// product is cast back to T.
//
// NOTE(review): this text is a rendered diff without +/- markers, so the
// erf-based statements and the normcdf-based statements appear side by side.
// As written, the first return is the one that executes and everything after
// it is unreachable. Presumably the commit removes the erf-based lines and
// keeps the kBeta/cdf/pdf form -- confirm against the repository.
inline HOSTDEVICE T operator()(T arg_x, T arg_dout) {
MPType x = static_cast<MPType>(arg_x);
MPType dout = static_cast<MPType>(arg_dout);
MPType one = static_cast<MPType>(1);
MPType half = static_cast<MPType>(0.5);
// kAlpha = (2/sqrt(pi)) * (1/sqrt(2)) = sqrt(2/pi)
MPType kAlpha = static_cast<MPType>(M_2_SQRTPI * M_SQRT1_2);
// First term: 0.5 * (1 + erf(x / sqrt(2))).
// Second term: x * (0.5 * kAlpha) * exp(-x^2 / 2), where 0.5 * kAlpha equals
// 1/sqrt(2*pi).
auto ans = half * (one + erf(x * static_cast<MPType>(M_SQRT1_2))) +
half * kAlpha * x * exp(-half * x * x);
return static_cast<T>(ans * dout);
// kBeta = (2/sqrt(pi)) * (1/sqrt(2)) * 0.5 = 1/sqrt(2*pi), i.e. the same
// factor as half * kAlpha above.
constexpr MPType kBeta = M_2_SQRTPI * M_SQRT1_2 * static_cast<MPType>(0.5);
// Assumes normcdf is the standard normal CDF (CUDA math API) -- TODO confirm.
const MPType cdf = normcdf(x);
// exp(-x^2 / 2) / sqrt(2*pi)
const MPType pdf = exp(static_cast<MPType>(-0.5) * x * x) * kBeta;
// dout * (cdf + x * pdf): the same expression as ans * dout when cdf is
// 0.5 * (1 + erf(x / sqrt(2))).
return static_cast<T>(dout * (cdf + x * pdf));
}
};

Expand Down

0 comments on commit 3702248

Please sign in to comment.