Skip to content

Commit 15ed987

Browse files
[Precision Depth Alignment] paddle.tan backward (gradient) calculation: dx = dout * (1 + tan(x)^2) (PaddlePaddle#75335)
* Tan backward (gradient) calculation: dx = dout * (1 + tan(x)^2)
1 parent d5ff262 commit 15ed987

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3868,12 +3868,22 @@ template <typename T>
38683868
struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
38693869
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
38703870

3871-
// dx = dout / cos(x)^2
3871+
// dx = dout * (1 + tan(x)^2)   (mathematically equal to dout / cos(x)^2)
38723872
__device__ __forceinline__ T operator()(const T arg_dout,
38733873
const T arg_x) const {
38743874
MPType dout = static_cast<MPType>(arg_dout);
38753875
MPType x = static_cast<MPType>(arg_x);
3876-
return static_cast<T>(dout / (cos(x) * cos(x)));
3876+
if constexpr (std::is_same<MPType, double>::value) {
3877+
double td = ::tan(x);
3878+
double tsq = __dmul_rn(td, td);
3879+
double y = __dadd_rn(tsq, 1.0);
3880+
return static_cast<T>(dout * y);
3881+
} else {
3882+
float tf = ::tanf(x);
3883+
float tsq = __fmul_rn(tf, tf);
3884+
float y = __fadd_rn(tsq, 1.0f);
3885+
return static_cast<T>(dout * y);
3886+
}
38773887
}
38783888

38793889
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
@@ -3882,10 +3892,11 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
38823892
template <typename T>
38833893
struct CudaTanGradFunctor<ComplexType<T>>
38843894
: public BaseActivationFunctor<ComplexType<T>> {
3885-
// dx = dout / cos(x)^2
3895+
// dx = dout * conj(1 + tan(x)^2)
38863896
__device__ __forceinline__ ComplexType<T> operator()(
38873897
const ComplexType<T> dout, const ComplexType<T> x) const {
3888-
return static_cast<ComplexType<T>>(dout / conj(cos(x) * cos(x)));
3898+
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
3899+
return static_cast<ComplexType<T>>(dout * conj(tan(x) * tan(x) + one));
38893900
}
38903901

38913902
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }

0 commit comments

Comments
 (0)