Skip to content

Commit 62aedce

Browse files
authored
Merge pull request PaddlePaddle#3553 from reyoung/feature/unittest_for_mean_grad
Add MeanOp's Gradient Test And Fix Mean Op Gradient
2 parents c68bfc3 + 7f8c3f8 commit 62aedce

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

paddle/operators/mean_op.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel {
5555
IG->mutable_data<T>(context.GetPlace());
5656

5757
T ig_size = (T)framework::product(IG->dims());
58+
Eigen::DSizes<int, 1> bcast(ig_size);
5859

5960
EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) =
60-
EigenScalar<T>::From(*OG) / ig_size;
61+
(EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
6162
}
6263
};
6364

python/paddle/v2/framework/tests/test_mean_op.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
from op_test_util import OpTestMeta
3+
from gradient_checker import GradientChecker, create_op
34
import numpy as np
45

56

@@ -12,5 +13,12 @@ def setUp(self):
1213
self.outputs = {'Out': np.mean(self.inputs['X'])}
1314

1415

16+
class MeanGradOpTest(GradientChecker):
17+
def test_normal(self):
18+
op = create_op("mean")
19+
inputs = {"X": np.random.random((10, 10)).astype("float32")}
20+
self.check_grad(op, inputs, set("X"), "Out")
21+
22+
1523
if __name__ == '__main__':
1624
unittest.main()

0 commit comments

Comments
 (0)