Skip to content

Commit 7a891a3

Browse files
authored
Merge pull request #4042 from hedaoyuan/conv_op
Convolution operator
2 parents b5e67fc + 6c0129a commit 7a891a3

File tree

5 files changed

+493
-9
lines changed

5 files changed

+493
-9
lines changed

paddle/framework/tensor_impl.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,19 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
130130
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
131131
PADDLE_ENFORCE_LT(begin_idx, end_idx,
132132
"Begin index must be less than end index.");
133-
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
134-
size_t base = numel() / dims_[0];
135-
Tensor dst;
136-
dst.holder_ = holder_;
137-
DDim dst_dims = dims_;
138-
dst_dims[0] = end_idx - begin_idx;
139-
dst.Resize(dst_dims);
140-
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
141-
return dst;
133+
134+
if (dims_[0] == 1) {
135+
return *this;
136+
} else {
137+
size_t base = numel() / dims_[0];
138+
Tensor dst;
139+
dst.holder_ = holder_;
140+
DDim dst_dims = dims_;
141+
dst_dims[0] = end_idx - begin_idx;
142+
dst.Resize(dst_dims);
143+
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
144+
return dst;
145+
}
142146
}
143147

144148
inline Tensor& Tensor::Resize(const DDim& dims) {

paddle/operators/conv2d_op.cc

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/gemm_conv2d_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
// Computes the extent of one spatial dimension of a convolution output:
//
//   floor((input_size - filter_size + 2 * padding) / stride) + 1
//
// Parameters:
//   input_size  - extent of the input along this dimension.
//   filter_size - extent of the filter along this dimension.
//   padding     - padding applied on each side of the input.
//   stride      - step between consecutive filter positions.
//
// Returns the number of filter positions along the dimension, or 0 when no
// valid position exists (non-positive stride, or a filter larger than the
// padded input). For valid configurations the result is identical to the
// plain formula above.
int outputSize(int input_size, int filter_size, int padding, int stride) {
  // A zero stride would divide by zero and a negative one is meaningless.
  if (stride <= 0) {
    return 0;
  }
  const int span = input_size - filter_size + 2 * padding;
  // C++ integer division truncates toward zero, so a negative span could
  // otherwise round up to a bogus positive size (e.g. -1 / 2 + 1 == 1).
  if (span < 0) {
    return 0;
  }
  return span / stride + 1;
}
24+
25+
class Conv2DOp : public framework::OperatorWithKernel {
26+
public:
27+
using framework::OperatorWithKernel::OperatorWithKernel;
28+
29+
protected:
30+
void InferShape(const framework::InferShapeContext &ctx) const override {
31+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
32+
"Input(Input) of Conv2DOp should not be null.");
33+
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
34+
"Input(Filter) of Conv2DOp should not be null.");
35+
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
36+
"Output(Output) of Conv2DOp should not be null.");
37+
38+
auto in = ctx.Input<Tensor>("Input");
39+
auto filter = ctx.Input<Tensor>("Filter");
40+
auto out = ctx.Output<framework::LoDTensor>("Output");
41+
std::vector<int> strides = Attr<std::vector<int>>("strides");
42+
std::vector<int> paddings = Attr<std::vector<int>>("paddings");
43+
int groups = Attr<int>("groups");
44+
int input_channels = in->dims()[1];
45+
int output_channels = filter->dims()[0];
46+
47+
PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
48+
PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
49+
"Conv2DOp filter should be 4-D.");
50+
PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
51+
"The number of input channels should be equal to filter "
52+
"channels * groups.");
53+
PADDLE_ENFORCE_EQ(
54+
output_channels % groups, 0,
55+
"The number of output channels should be divided by groups.");
56+
57+
auto output_height =
58+
outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
59+
auto output_width =
60+
outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
61+
out->Resize(
62+
{in->dims()[0], filter->dims()[0], output_height, output_width});
63+
}
64+
};
65+
66+
// Declares the proto of the conv2d operator: inputs "Input" and "Filter",
// output "Output", and the attributes "strides" (default {1, 1}),
// "paddings" (default {0, 0}) and "groups" (default 1).
class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "Input",
        "The input tensor of convolution operator. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of image.");
    // Adjacent literals are concatenated verbatim, so each sentence needs its
    // own trailing space to keep the generated description readable.
    AddInput(
        "Filter",
        "The filter tensor of convolution operator. "
        "The format of the filter tensor is MCHW, where M is the number of "
        "output image channels, C is the number of input image channels, "
        "H and W is height and width of filter. "
        "If the groups attribute is greater than 1, C equal the number of "
        "input image channels divided by the groups.");
    AddOutput("Output",
              "The output tensor of convolution operator. "
              "The format of output tensor is also NCHW.");
    AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
        .SetDefault({0, 0});
    AddAttr<int>(
        "groups",
        "group size of convolution operator. "
        "Refer to grouped convolution in Alex Krizhevsky's paper: "
        "when group=2, the first half of the filters are only connected to the "
        "first half of the input channels, and the second half only connected "
        "to the second half.")
        .SetDefault(1);
    AddComment(R"DOC(
The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape.
)DOC");
  }
};
105+
106+
class Conv2DOpGrad : public framework::OperatorWithKernel {
107+
public:
108+
using framework::OperatorWithKernel::OperatorWithKernel;
109+
110+
protected:
111+
void InferShape(const framework::InferShapeContext &ctx) const override {
112+
auto in = ctx.Input<Tensor>("Input");
113+
auto filter = ctx.Input<Tensor>("Filter");
114+
auto d_in =
115+
ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
116+
auto d_filter =
117+
ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
118+
if (d_in) d_in->Resize(in->dims());
119+
if (d_filter) d_filter->Resize(filter->dims());
120+
}
121+
};
122+
123+
} // namespace operators
124+
} // namespace paddle
125+
126+
// Short alias for paddle::operators, used by the registrations that follow.
namespace ops = paddle::operators;
127+
// Registers the op type `conv2d` with Conv2DOp and Conv2DOpMaker, and the
// op type `conv2d_grad` with Conv2DOpGrad.
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
128+
ops::Conv2DOpGrad);
129+
130+
// CPU kernels for both op types; only the float instantiation is registered
// here.
REGISTER_OP_CPU_KERNEL(
131+
conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
132+
REGISTER_OP_CPU_KERNEL(
133+
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);

paddle/operators/conv2d_op.cu

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/gemm_conv2d_op.h"
16+
17+
// Short alias for paddle::operators, used by the registrations that follow.
namespace ops = paddle::operators;
18+
19+
// GPU kernels for `conv2d` and `conv2d_grad`, using the same GemmConv2DKernel
// and GemmConvGrad2DKernel templates instantiated for GPUPlace; only the
// float instantiation is registered here.
REGISTER_OP_GPU_KERNEL(
20+
conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
21+
REGISTER_OP_GPU_KERNEL(
22+
conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);

0 commit comments

Comments
 (0)