File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -3500,7 +3500,11 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
35003500 const T arg_x) const {
35013501 MPType dout = static_cast <MPType>(arg_dout);
35023502 MPType x = static_cast <MPType>(arg_x);
3503- return static_cast <T>(-dout * sin (x));
3503+ if constexpr (std::is_same<T, phi::float16>::value) {
3504+ return static_cast <T>(-arg_dout * static_cast <T>(sin (x)));
3505+ } else {
3506+ return static_cast <T>(-dout * sin (x));
3507+ }
35043508 }
35053509
35063510 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
@@ -3835,7 +3839,11 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
38353839 const T arg_x) const {
38363840 MPType dout = static_cast <MPType>(arg_dout);
38373841 MPType x = static_cast <MPType>(arg_x);
3838- return static_cast <T>(dout * cos (x));
3842+ if constexpr (std::is_same<T, phi::float16>::value) {
3843+ return static_cast <T>(arg_dout * static_cast <T>(cos (x)));
3844+ } else {
3845+ return static_cast <T>(dout * cos (x));
3846+ }
38393847 }
38403848
38413849 static constexpr ActBwdOpFwdDeps FwdDeps () { return ActBwdOpFwdDeps::kDepX ; }
You can’t perform that action at this time.
0 commit comments