@@ -19,8 +19,8 @@ limitations under the License. */
19
19
namespace paddle {
20
20
namespace operators {
21
21
22
- inline void GetDims (const framework::DDim& dim, int axis,
23
- int * pre, int * n, int * post) {
22
+ inline void GetDims (const framework::DDim& dim, int axis, int * pre, int * n,
23
+ int * post) {
24
24
*pre = 1 ;
25
25
*post = 1 ;
26
26
*n = dim[axis];
@@ -49,7 +49,7 @@ class NormKernel : public framework::OpKernel<T> {
49
49
if (axis < 0 ) axis = xdim.size () + axis;
50
50
int pre, n, post;
51
51
GetDims (xdim, axis, &pre, &n, &post);
52
-
52
+
53
53
auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
54
54
55
55
Eigen::DSizes<int , 3 > shape (pre, n, post);
@@ -61,14 +61,13 @@ class NormKernel : public framework::OpKernel<T> {
61
61
auto x = x_e.reshape (shape);
62
62
auto y = y_e.reshape (shape);
63
63
auto norm = norm_e.reshape (norm_shape);
64
-
64
+
65
65
Eigen::DSizes<int , 1 > rdim (1 );
66
66
auto x_pow = x * x;
67
67
auto & device_ctx = ctx.template device_context <DeviceContext>();
68
68
math::SetConstant<DeviceContext, T>()(device_ctx, out_norm, eps);
69
69
70
70
// y = x / sqrt((sum(x * x) + epsilon))
71
-
72
71
// norm = sqrt(sum(x * x) + epsilon)
73
72
norm.device (*place) = norm + x_pow.eval ().sum (rdim) + eps;
74
73
norm.device (*place) = norm.sqrt ();
@@ -93,9 +92,8 @@ class NormGradKernel : public framework::OpKernel<T> {
93
92
if (axis < 0 ) axis = xdim.size () + axis;
94
93
int pre, n, post;
95
94
GetDims (xdim, axis, &pre, &n, &post);
96
-
97
- auto * place =
98
- ctx.template device_context <DeviceContext>().eigen_device ();
95
+
96
+ auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
99
97
100
98
auto x_e = framework::EigenVector<T>::Flatten (*in_x);
101
99
auto dy_e = framework::EigenVector<T>::Flatten (*in_dy);
@@ -112,7 +110,7 @@ class NormGradKernel : public framework::OpKernel<T> {
112
110
framework::Tensor rsum;
113
111
rsum.mutable_data <T>({pre, post}, ctx.GetPlace ());
114
112
auto sum = framework::EigenTensor<T, 2 >::From (rsum);
115
-
113
+
116
114
Eigen::DSizes<int , 1 > rdim (1 );
117
115
Eigen::DSizes<int , 3 > bcast (1 , n, 1 );
118
116
Eigen::DSizes<int , 3 > rshape (pre, 1 , post);
0 commit comments