@@ -3868,12 +3868,22 @@ template <typename T>
38683868struct CudaTanGradFunctor : public BaseActivationFunctor <T> {
38693869 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
38703870
3871- // dx = dout / cos (x)^2
3871+ // dx = dout *(1 + tan (x)^2)
// Backward of tan for real-valued T: returns dout * (1 + tan(x)^2).
// Both operands are widened to MPType, the factor is computed in that
// precision, and the product is cast back to T.
__device__ __forceinline__ T operator()(const T arg_dout,
                                        const T arg_x) const {
  const MPType grad_out = static_cast<MPType>(arg_dout);
  const MPType input = static_cast<MPType>(arg_x);
  if constexpr (std::is_same<MPType, double>::value) {
    // Double path. The explicit round-to-nearest multiply/add intrinsics are
    // presumably used so that tan^2 + 1 is not contracted into an FMA --
    // confirm with the author of the change.
    const double t = ::tan(input);
    const double scale = __dadd_rn(__dmul_rn(t, t), 1.0);
    return static_cast<T>(grad_out * scale);
  } else {
    // Every other MPType goes through the single-precision routines.
    const float t = ::tanf(input);
    const float scale = __fadd_rn(__fmul_rn(t, t), 1.0f);
    return static_cast<T>(grad_out * scale);
  }
}
38783888
// Returns ActBwdOpFwdDeps::kDepX. Presumably this tells the framework that
// the backward functor needs the forward input x -- confirm against the
// definition of ActBwdOpFwdDeps.
static constexpr ActBwdOpFwdDeps FwdDeps() {
  return ActBwdOpFwdDeps::kDepX;
}
@@ -3882,10 +3892,11 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
38823892template <typename T>
38833893struct CudaTanGradFunctor <ComplexType<T>>
38843894 : public BaseActivationFunctor<ComplexType<T>> {
3885- // dx = dout / cos (x)^2
3895+ // dx = dout *(1 + tan (x)^2)
// Backward of tan for complex inputs: returns dout * conj(1 + tan(x)^2).
__device__ __forceinline__ ComplexType<T> operator()(
    const ComplexType<T> dout, const ComplexType<T> x) const {
  const ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
  // Evaluate the complex tangent once and reuse it for the square instead of
  // calling tan(x) twice.
  const auto t = tan(x);
  return static_cast<ComplexType<T>>(dout * conj(t * t + one));
}
38903901
// Returns ActBwdOpFwdDeps::kDepX. Presumably this tells the framework that
// the backward functor needs the forward input x -- confirm against the
// definition of ActBwdOpFwdDeps.
static constexpr ActBwdOpFwdDeps FwdDeps() {
  return ActBwdOpFwdDeps::kDepX;
}
0 commit comments