Skip to content

Commit 7e60706

Browse files
authored
Merge pull request PaddlePaddle#3135 from gangliao/mean_op_op
cpu/gpu mean op and its unit test
2 parents f70e807 + 6e0661c commit 7e60706

File tree

9 files changed

+153
-10
lines changed

9 files changed

+153
-10
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,26 @@ endfunction()
4444
op_library(add_op SRCS add_op.cc add_op.cu)
4545
cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
4646

47+
op_library(mean_op SRCS mean_op.cc mean_op.cu)
48+
cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op)
49+
4750
op_library(mul_op SRCS mul_op.cc mul_op.cu)
4851
op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)
49-
op_library(sigmoid_op SRCS sigmoid_op.cu sigmoid_op.cc)
52+
53+
op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
5054
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
5155
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
5256
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)
5357

54-
op_library(fc_op SRCS fc_op.cc DEPS mul_op rowwise_add_op sigmoid_op
55-
softmax_op net)
56-
5758
op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)
5859

59-
op_library(recurrent_network_op SRCS recurrent_network_op.cc DEPS op_desc
60-
tensor op_registry operator net)
61-
cc_test(recurrent_network_op_test SRCS recurrent_network_op_test.cc DEPS
62-
recurrent_network_op gtest mul_op add_op)
60+
op_library(fc_op
61+
SRCS fc_op.cc
62+
DEPS mul_op rowwise_add_op sigmoid_op softmax_op net)
63+
64+
op_library(recurrent_network_op
65+
SRCS recurrent_network_op.cc
66+
DEPS op_desc tensor net)
67+
cc_test(recurrent_network_op_test
68+
SRCS recurrent_network_op_test.cc
69+
DEPS recurrent_network_op mul_op add_op)

paddle/operators/mean_op.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/mean_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class MeanOp : public OperatorWithKernel {
21+
protected:
22+
void InferShape(const InferShapeContext &ctx) const override {
23+
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
24+
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
25+
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
26+
"Input/Output of MeanOp must be initialized.");
27+
ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
28+
}
29+
};
30+
31+
class MeanOpMaker : public OpProtoAndCheckerMaker {
32+
public:
33+
MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
34+
: OpProtoAndCheckerMaker(proto, op_checker) {
35+
AddInput("X", "The input of mean op");
36+
AddOutput("Out", "The output of mean op");
37+
AddComment("Mean Operator");
38+
}
39+
};
40+
41+
} // namespace operators
42+
} // namespace paddle
43+
44+
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
45+
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);

paddle/operators/mean_op.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#define EIGEN_USE_GPU
2+
3+
#include "paddle/operators/mean_op.h"
4+
5+
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);

paddle/operators/mean_op.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/operators/type_alias.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename Place, typename T>
22+
class MeanKernel : public OpKernel {
23+
public:
24+
void Compute(const ExecutionContext& context) const override {
25+
auto input = context.Input<Tensor>(0);
26+
auto output = context.Output<Tensor>(0);
27+
28+
output->mutable_data<T>(context.GetPlace());
29+
30+
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
31+
EigenVector<T>::Flatten(*input).mean();
32+
}
33+
};
34+
35+
} // namespace operators
36+
} // namespace paddle

paddle/operators/mean_op_test.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <gtest/gtest.h>
16+
17+
#include <paddle/framework/op_registry.h>
18+
19+
USE_OP(mean);
20+
21+
TEST(MeanOp, GetOpProto) {
22+
auto& protos = paddle::framework::OpRegistry::protos();
23+
auto it = protos.find("mean");
24+
ASSERT_NE(it, protos.end());
25+
}

paddle/pybind/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,9 @@
1-
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
2-
add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
1+
cc_library(paddle_pybind SHARED
2+
SRCS pybind.cc
3+
DEPS pybind python
4+
fc_op
5+
sgd_op
6+
add_op
7+
mean_op
8+
cross_entropy_op
9+
recurrent_network_op)

paddle/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ USE_OP(onehot_cross_entropy);
3333
USE_OP_WITHOUT_KERNEL(fc);
3434
USE_OP(sgd);
3535
USE_OP(mul);
36+
USE_OP(mean);
3637
USE_OP(sigmoid);
3738
USE_OP(softmax);
3839
USE_OP(rowwise_add);

python/paddle/v2/framework/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_python_test(test_framework
1010
test_sgd_op.py
1111
test_cross_entropy_op.py
1212
test_mul_op.py
13+
test_mean_op.py
1314
test_sigmoid_op.py
1415
test_softmax_op.py
1516
test_rowwise_add_op.py
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import unittest
2+
from op_test_util import OpTestMeta
3+
import numpy as np
4+
5+
6+
class TestMeanOp(unittest.TestCase):
7+
__metaclass__ = OpTestMeta
8+
9+
def setUp(self):
10+
self.type = "mean"
11+
self.X = np.random.random((32, 784)).astype("float32")
12+
self.Out = np.mean(self.X)
13+
14+
15+
if __name__ == '__main__':
16+
unittest.main()

0 commit comments

Comments
 (0)