Skip to content

Commit f97f6c8

Browse files
accuracy_stable_sqrt (#75367)
paddle.sqrt and paddle.rsqrt accuracy and torch alignment
1 parent 2a60a6a commit f97f6c8

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4330,11 +4330,11 @@ struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
43304330

43314331
template <typename T>
43324332
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
4333-
T one_half = static_cast<T>(0.5f);
4333+
T two = static_cast<T>(2);
43344334

4335-
// dx = dout * 0.5 / out
4335+
// dx = dout / (2 * out)
43364336
__device__ __forceinline__ T operator()(const T dout, const T out) const {
4337-
return one_half * dout / out;
4337+
return dout / (two * out);
43384338
}
43394339

43404340
static constexpr ActBwdOpFwdDeps FwdDeps() {
@@ -4421,7 +4421,7 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
44214421
const T arg_out) const {
44224422
MPType dout = static_cast<MPType>(arg_dout);
44234423
MPType out = static_cast<MPType>(arg_out);
4424-
return static_cast<T>(minus_one_half * dout * out * out * out);
4424+
return static_cast<T>(minus_one_half * dout * (out * out * out));
44254425
}
44264426

44274427
static constexpr ActBwdOpFwdDeps FwdDeps() {

0 commit comments

Comments
 (0)