Skip to content

Commit 3a2724b

Browse files
committed
Upgrade pre-commit to fix the code style.
1 parent 0f66adf commit 3a2724b

File tree

3 files changed

+10
-17
lines changed

3 files changed

+10
-17
lines changed

paddle/fluid/operators/norm_op.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@ namespace operators {
1919
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
2020
public:
2121
void Make() override {
22-
AddInput("X",
23-
"(Tensor) A tensor of rank >= axis.");
22+
AddInput("X", "(Tensor) A tensor of rank >= axis.");
2423
AddAttr<int>("axis",
2524
"The axis on which to apply normalization. If axis < 0, "
2625
"the dimension to normalization is rank(X) + axis. -1 is "
27-
"the last dimension.");
26+
"the last dimension.");
2827
AddAttr<float>("epsilon",
2928
"(float, default 1e-10) The epsilon value is used "
3029
"to avoid division by zero.")
@@ -33,8 +32,7 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
3332
"(Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will "
3433
"be used in backward kernel.")
3534
.AsIntermediate();
36-
AddOutput("Out",
37-
"(Tensor) A tensor of the same shape as X.");
35+
AddOutput("Out", "(Tensor) A tensor of the same shape as X.");
3836
AddComment(R"DOC(
3937
4038
Given a tensor, apply 2-normalization along the provided axis.

paddle/fluid/operators/norm_op.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace operators {
2121

22-
inline void GetDims(const framework::DDim& dim, int axis,
23-
int* pre, int* n, int* post) {
22+
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
23+
int* post) {
2424
*pre = 1;
2525
*post = 1;
2626
*n = dim[axis];
@@ -49,7 +49,7 @@ class NormKernel : public framework::OpKernel<T> {
4949
if (axis < 0) axis = xdim.size() + axis;
5050
int pre, n, post;
5151
GetDims(xdim, axis, &pre, &n, &post);
52-
52+
5353
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
5454

5555
Eigen::DSizes<int, 3> shape(pre, n, post);
@@ -61,14 +61,13 @@ class NormKernel : public framework::OpKernel<T> {
6161
auto x = x_e.reshape(shape);
6262
auto y = y_e.reshape(shape);
6363
auto norm = norm_e.reshape(norm_shape);
64-
64+
6565
Eigen::DSizes<int, 1> rdim(1);
6666
auto x_pow = x * x;
6767
auto& device_ctx = ctx.template device_context<DeviceContext>();
6868
math::SetConstant<DeviceContext, T>()(device_ctx, out_norm, eps);
6969

7070
// y = x / sqrt((sum(x * x) + epsilon))
71-
7271
// norm = sqrt(sum(x * x) + epsilon)
7372
norm.device(*place) = norm + x_pow.eval().sum(rdim) + eps;
7473
norm.device(*place) = norm.sqrt();
@@ -93,9 +92,8 @@ class NormGradKernel : public framework::OpKernel<T> {
9392
if (axis < 0) axis = xdim.size() + axis;
9493
int pre, n, post;
9594
GetDims(xdim, axis, &pre, &n, &post);
96-
97-
auto* place =
98-
ctx.template device_context<DeviceContext>().eigen_device();
95+
96+
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
9997

10098
auto x_e = framework::EigenVector<T>::Flatten(*in_x);
10199
auto dy_e = framework::EigenVector<T>::Flatten(*in_dy);
@@ -112,7 +110,7 @@ class NormGradKernel : public framework::OpKernel<T> {
112110
framework::Tensor rsum;
113111
rsum.mutable_data<T>({pre, post}, ctx.GetPlace());
114112
auto sum = framework::EigenTensor<T, 2>::From(rsum);
115-
113+
116114
Eigen::DSizes<int, 1> rdim(1);
117115
Eigen::DSizes<int, 3> bcast(1, n, 1);
118116
Eigen::DSizes<int, 3> rshape(pre, 1, post);

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def test_polygon_box_transform(self):
387387
self.assertIsNotNone(output)
388388
print(str(program))
389389

390-
391390
def test_l2_normalize(self):
392391
program = Program()
393392
with program_guard(program):
@@ -397,7 +396,5 @@ def test_l2_normalize(self):
397396
print(str(program))
398397

399398

400-
401-
402399
if __name__ == '__main__':
403400
unittest.main()

0 commit comments

Comments
 (0)