Skip to content

Commit 7b3ac4b

Browse files
Softplus accuracy and torch alignment 1 (#75363)
1 parent c3a89b6 commit 7b3ac4b

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -4207,10 +4207,19 @@ struct CudaSTanhGradFunctor<ComplexType<T>>
42074207
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
42084208
};
42094209

4210+
// Generic overload of log1p_local: hands the argument straight to log1p()
// and returns the result converted to T.
template <typename T>
__device__ __forceinline__ T log1p_local(T x) {
  const T result = log1p(x);
  return result;
}
4214+
4215+
// Complex overload of log1p_local: returns log(1 + x) for a ComplexType<T>
// argument, built from the complex log() and operator+.
//
// The argument is used as-is and must NOT be exponentiated here. The softplus
// forward functor calls log1p_local(exp(x * b)), i.e. the exponential has
// already been applied by the caller; applying exp() again inside this
// function would evaluate log(1 + exp(exp(x * b))) instead of softplus.
template <typename T>
__device__ __forceinline__ ComplexType<T> log1p_local(ComplexType<T> x) {
  return log(ComplexType<T>{1.} + x);
}
4219+
42104220
template <typename T>
42114221
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
42124222
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
4213-
MPType one = static_cast<MPType>(1.0f);
42144223
float beta;
42154224
float threshold;
42164225

@@ -4223,8 +4232,7 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
42234232
MPType x = static_cast<MPType>(arg_x);
42244233
MPType b = static_cast<MPType>(beta);
42254234
MPType t = static_cast<MPType>(threshold);
4226-
MPType x_beta = x * static_cast<MPType>(beta);
4227-
return static_cast<T>(x_beta > t ? x : log(one + exp(x_beta)) / b);
4235+
return static_cast<T>((x * b) > t ? x : (log1p_local(exp(x * b))) / b);
42284236
}
42294237
};
42304238

@@ -4246,8 +4254,8 @@ struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
42464254
MPType x = static_cast<MPType>(arg_x);
42474255
MPType b = static_cast<MPType>(beta);
42484256
MPType t = static_cast<MPType>(threshold);
4249-
MPType x_beta = x * beta;
4250-
return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
4257+
MPType z = std::exp(x * b);
4258+
return (x * b) > t ? arg_dout : static_cast<T>(dout * z / (z + one));
42514259
}
42524260

42534261
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
@@ -4272,10 +4280,10 @@ struct CudaSoftplusGradFunctor<ComplexType<T>>
42724280
MPType x = static_cast<MPType>(arg_x);
42734281
MPType b = static_cast<MPType>(beta);
42744282
MPType t = static_cast<MPType>(threshold);
4275-
MPType x_beta = x * static_cast<MPType>(beta);
4276-
return x_beta > t
4283+
MPType z = exp(x * b);
4284+
return (x * b) > t
42774285
? dout
4278-
: static_cast<ComplexType<T>>(dout / conj(one + exp(-x_beta)));
4286+
: static_cast<ComplexType<T>>(dout * conj(z / (z + one)));
42794287
}
42804288

42814289
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }

0 commit comments

Comments
 (0)