File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -4330,11 +4330,11 @@ struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
43304330
43314331template <typename T>
43324332struct CudaSqrtGradFunctor : public BaseActivationFunctor <T> {
4333- T one_half = static_cast <T>(0 . 5f );
4333+ T two = static_cast <T>(2 );
43344334
4335- // dx = dout * 0.5 / out
4335+ // dx = dout / (2 * out)
43364336 __device__ __forceinline__ T operator ()(const T dout, const T out) const {
4337- return one_half * dout / out;
4337+ return dout / (two * out) ;
43384338 }
43394339
43404340 static constexpr ActBwdOpFwdDeps FwdDeps () {
@@ -4421,7 +4421,7 @@ struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
44214421 const T arg_out) const {
44224422 MPType dout = static_cast <MPType>(arg_dout);
44234423 MPType out = static_cast <MPType>(arg_out);
4424- return static_cast <T>(minus_one_half * dout * out * out * out);
4424+ return static_cast <T>(minus_one_half * dout * ( out * out * out) );
44254425 }
44264426
44274427 static constexpr ActBwdOpFwdDeps FwdDeps () {
You can’t perform that action at this time.
0 commit comments