Skip to content

Commit 30b7970

Browse files
[Precision Depth Alignment] paddle.sin and paddle.cos aligns with torch precision. (PaddlePaddle#75503)
* accuracy_stable_sin * accuracy_stable_cos
1 parent 0c90043 commit 30b7970

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3500,7 +3500,11 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
35003500
const T arg_x) const {
35013501
MPType dout = static_cast<MPType>(arg_dout);
35023502
MPType x = static_cast<MPType>(arg_x);
3503-
return static_cast<T>(-dout * sin(x));
3503+
if constexpr (std::is_same<T, phi::float16>::value) {
3504+
return static_cast<T>(-arg_dout * static_cast<T>(sin(x)));
3505+
} else {
3506+
return static_cast<T>(-dout * sin(x));
3507+
}
35043508
}
35053509

35063510
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
@@ -3835,7 +3839,11 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
38353839
const T arg_x) const {
38363840
MPType dout = static_cast<MPType>(arg_dout);
38373841
MPType x = static_cast<MPType>(arg_x);
3838-
return static_cast<T>(dout * cos(x));
3842+
if constexpr (std::is_same<T, phi::float16>::value) {
3843+
return static_cast<T>(arg_dout * static_cast<T>(cos(x)));
3844+
} else {
3845+
return static_cast<T>(dout * cos(x));
3846+
}
38393847
}
38403848

38413849
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }

0 commit comments

Comments
 (0)