@@ -4207,10 +4207,19 @@ struct CudaSTanhGradFunctor<ComplexType<T>>
42074207 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
42084208};
42094209
4210+ template <typename T>
4211+ __device__ __forceinline__ T log1p_local (T x) {
4212+ return log1p (x);
4213+ }
4214+
4215+ template <typename T>
4216+ __device__ __forceinline__ ComplexType<T> log1p_local (ComplexType<T> x) {
4217+ return log (ComplexType<T>{1 .} + exp (x));
4218+ }
4219+
42104220template <typename T>
42114221struct CudaSoftplusFunctor : public BaseActivationFunctor <T> {
42124222 using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
4213- MPType one = static_cast <MPType>(1 .0f );
42144223 float beta;
42154224 float threshold;
42164225
@@ -4223,8 +4232,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
42234232 MPType x = static_cast <MPType>(arg_x);
42244233 MPType b = static_cast <MPType>(beta);
42254234 MPType t = static_cast <MPType>(threshold);
4226- MPType x_beta = x * static_cast <MPType>(beta);
4227- return static_cast <T>(x_beta > t ? x : log (one + exp (x_beta)) / b);
4235+ return static_cast <T>((x * b) > t ? x : (log1p_local (exp (x * b))) / b);
42284236 }
42294237};
42304238
@@ -4246,8 +4254,8 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
42464254 MPType x = static_cast <MPType>(arg_x);
42474255 MPType b = static_cast <MPType>(beta);
42484256 MPType t = static_cast <MPType>(threshold);
4249- MPType x_beta = x * beta ;
4250- return x_beta > t ? arg_dout : static_cast <T>(dout / (one + exp (-x_beta) ));
4257+ MPType z = std::exp ( x * b) ;
4258+ return (x * b) > t ? arg_dout : static_cast <T>(dout * z / (z + one ));
42514259 }
42524260
42534261 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
@@ -4272,10 +4280,10 @@ struct CudaSoftplusGradFunctor<ComplexType<T>>
42724280 MPType x = static_cast <MPType>(arg_x);
42734281 MPType b = static_cast <MPType>(beta);
42744282 MPType t = static_cast <MPType>(threshold);
4275- MPType x_beta = x * static_cast <MPType>(beta );
4276- return x_beta > t
4283+ MPType z = exp ( x * b );
4284+ return (x * b) > t
42774285 ? dout
4278- : static_cast <ComplexType<T>>(dout / conj (one + exp (-x_beta )));
4286+ : static_cast <ComplexType<T>>(dout * conj (z / (z + one )));
42794287 }
42804288
42814289 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
0 commit comments