Skip to content

Commit 7bd5171

Browse files
authored
Add Concat operator with CPU kernel (#3775)
add concat op with CPU kernel
1 parent 4fbc03d commit 7bd5171

File tree

9 files changed

+211
-13
lines changed

9 files changed

+211
-13
lines changed

paddle/operators/concat_op.cc

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/concat_op.h"
16+
#include <vector>
17+
18+
namespace paddle {
19+
namespace operators {
20+
using framework::Tensor;
21+
22+
class ConcatOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
protected:
27+
void InferShape(const framework::InferShapeContext &ctx) const override {
28+
auto ins = ctx.MultiInput<framework::Tensor>("X");
29+
auto *out = ctx.Output<framework::Tensor>("Out");
30+
size_t axis = static_cast<size_t>(ctx.Attr<int>("axis"));
31+
size_t n = ins.size();
32+
33+
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
34+
35+
auto out_dims = ins[0]->dims();
36+
size_t in_zero_dims_size = out_dims.size();
37+
for (size_t i = 1; i < n; i++) {
38+
for (size_t j = 0; j < in_zero_dims_size; j++) {
39+
if (j == axis) {
40+
out_dims[axis] += ins[i]->dims()[j];
41+
continue;
42+
}
43+
PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
44+
"Input tensors should have the same "
45+
"elements except the specify axis.")
46+
}
47+
}
48+
out->Resize(out_dims);
49+
}
50+
};
51+
52+
class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
53+
public:
54+
ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
55+
: OpProtoAndCheckerMaker(proto, op_checker) {
56+
AddInput("X", "the input tensors of concat operator.").AsDuplicable();
57+
AddOutput("Out", "the output tensor of concat operator.");
58+
AddComment(R"DOC(
59+
Join the input tensors along with the axis.
60+
Examples:
61+
Input[0] = [[1,2],[3,4]]
62+
Input[1] = [[5,6]]
63+
axis = 0
64+
Output = [[1,2],
65+
[3,4],
66+
[5,6]]
67+
)DOC");
68+
AddAttr<int>("axis", "The axis which the inputs will be joined with.")
69+
.SetDefault(0);
70+
}
71+
};
72+
73+
} // namespace operators
74+
} // namespace paddle
75+
76+
namespace ops = paddle::operators;
77+
REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
78+
REGISTER_OP_CPU_KERNEL(concat,
79+
ops::ConcatKernel<paddle::platform::CPUPlace, float>)

paddle/operators/concat_op.cu

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#define EIGEN_USE_GPU
16+
#include "paddle/operators/concat_op.h"
17+
18+
namespace ops = paddle::operators;
19+
// TODO(Yancey1989) Add GPU kernel

paddle/operators/concat_op.h

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "paddle/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename Place, typename T>
24+
class ConcatKernel : public framework::OpKernel {
25+
public:
26+
void Compute(const framework::ExecutionContext& ctx) const override {
27+
auto ins = ctx.MultiInput<framework::Tensor>("X");
28+
auto* out = ctx.Output<framework::Tensor>("Out");
29+
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
30+
size_t n = ins.size();
31+
size_t output_axis_dim = 0;
32+
size_t before = 1, after = 1;
33+
for (size_t i = 0; i < n; i++) {
34+
output_axis_dim += ins[i]->dims()[axis];
35+
}
36+
auto& input_zero = ins[0];
37+
for (int64_t i = 0; i < input_zero->dims().size(); i++) {
38+
if (i == axis) {
39+
continue;
40+
}
41+
if (i < axis) {
42+
before *= input_zero->dims()[i];
43+
} else {
44+
after *= input_zero->dims()[i];
45+
}
46+
}
47+
size_t output_offset = 0;
48+
for (size_t i = 0; i < n; i++) {
49+
auto& in = ins[i];
50+
auto axis_dim = in->dims()[axis];
51+
for (size_t j = 0; j < before; j++) {
52+
size_t len = axis_dim * after * sizeof(T);
53+
const T* src = in->data<T>() + axis_dim * after * j;
54+
T* out_data = out->mutable_data<T>(platform::CPUPlace());
55+
T* dest = out_data + output_offset + output_axis_dim * after * j;
56+
memcpy(dest, src, len);
57+
}
58+
output_offset += axis_dim * after;
59+
}
60+
}
61+
};
62+
63+
} // namespace operators
64+
} // namespace paddle

paddle/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ USE_OP(minus);
4949
USE_OP(cos_sim);
5050
USE_CPU_ONLY_OP(gather);
5151
USE_CPU_ONLY_OP(scatter);
52+
USE_CPU_ONLY_OP(concat);
5253
USE_OP(top_k);
5354
USE_OP(squared_l2_distance);
5455
USE_OP(sum);

python/paddle/v2/framework/op.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def __call__(self, *args, **kwargs):
4343
if len(args) != 0:
4444
raise ValueError("Only keyword arguments are supported.")
4545
op_desc = framework_pb2.OpDesc()
46-
4746
for input_parameter in self.__op_proto__.inputs:
4847
input_arguments = kwargs.get(input_parameter.name, [])
4948
if is_str(input_arguments):

python/paddle/v2/framework/tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ py_test(test_lookup_table SRCS test_lookup_table.py)
3535
py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
3636
py_test(test_sum_op SRCS test_sum_op.py)
3737
py_test(mnist SRCS mnist.py)
38+
py_test(test_concat_op SRCS test_concat_op.py)
3839
py_test(test_squared_l2_distance_op SRCS test_squared_l2_distance_op.py)

python/paddle/v2/framework/tests/gradient_checker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
def create_op(op_type):
1212
# TODO need to set attrs
1313
kwargs = dict()
14-
for in_name in Operator.get_op_input_names(op_type):
14+
for in_name, _ in Operator.get_op_input_names(op_type):
1515
kwargs[in_name] = in_name
16-
for out_name in Operator.get_op_output_names(op_type):
16+
for out_name, _ in Operator.get_op_output_names(op_type):
1717
kwargs[out_name] = out_name
18-
1918
return Operator(op_type, **kwargs)
2019

2120

python/paddle/v2/framework/tests/op_test_util.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,30 @@ def test_all(self):
2727
places.append(core.GPUPlace(0))
2828

2929
for place in places:
30-
for in_name in Operator.get_op_input_names(self.type):
31-
if hasattr(self, "inputs") and in_name in self.inputs:
32-
kwargs[in_name] = in_name
33-
var = scope.new_var(in_name).get_tensor()
34-
arr = self.inputs[in_name]
35-
var.set_dims(arr.shape)
36-
var.set(arr, place)
30+
for ins in Operator.get_op_input_names(self.type):
31+
in_name = ins[0]
32+
in_dup = ins[1]
33+
if hasattr(self, 'inputs') and in_name in self.inputs:
34+
kwargs[in_name] = []
35+
if in_dup:
36+
arrays = self.inputs[in_name]
37+
for index, arr in enumerate(arrays):
38+
var = scope.new_var(in_name + str(index))
39+
tensor = var.get_tensor()
40+
tensor.set_dims(arr.shape)
41+
tensor.set(arr, place)
42+
kwargs[in_name].append(in_name + str(index))
43+
else:
44+
kwargs[in_name] = in_name
45+
var = scope.new_var(in_name).get_tensor()
46+
arr = self.inputs[in_name]
47+
var.set_dims(arr.shape)
48+
var.set(arr, place)
3749
else:
3850
kwargs[in_name] = "@EMPTY@"
3951

40-
for out_name in Operator.get_op_output_names(self.type):
52+
for out_name, out_dup in Operator.get_op_output_names(
53+
self.type):
4154
if not hasattr(self, "outputs"):
4255
raise ValueError(
4356
"The test op must set self.outputs dict.")
@@ -60,7 +73,8 @@ def test_all(self):
6073
ctx = core.DeviceContext.create(place)
6174
op.run(scope, ctx)
6275

63-
for out_name in Operator.get_op_output_names(self.type):
76+
for out_name, out_dup in Operator.get_op_output_names(
77+
self.type):
6478
actual = numpy.array(scope.find_var(out_name).get_tensor())
6579
expect = self.outputs[out_name]
6680
self.assertTrue(
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import unittest
2+
import numpy as np
3+
from gradient_checker import GradientChecker, create_op
4+
from op_test_util import OpTestMeta
5+
6+
7+
class TestConcatOp(unittest.TestCase):
8+
__metaclass__ = OpTestMeta
9+
10+
def setUp(self):
11+
self.type = "concat"
12+
x0 = np.random.random((2, 3, 2, 5)).astype('float32')
13+
x1 = np.random.random((2, 3, 3, 5)).astype('float32')
14+
x2 = np.random.random((2, 3, 4, 5)).astype('float32')
15+
axis = 2
16+
self.inputs = {'X': [x0, x1, x2]}
17+
self.attrs = {'axis': axis}
18+
self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}
19+
20+
21+
if __name__ == '__main__':
22+
unittest.main()

0 commit comments

Comments
 (0)