Skip to content

Commit 6c6474c

Browse files
committed
follow comments
1 parent fcfce48 commit 6c6474c

File tree

3 files changed

+54
-41
lines changed

3 files changed

+54
-41
lines changed

paddle/operators/CMakeLists.txt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,20 @@ function(op_library TARGET)
5555
set(pybind_flag 1)
5656
endif()
5757

58+
# pool_op contains several operators
5859
if ("${TARGET}" STREQUAL "pool_op")
5960
set(pybind_flag 1)
6061
# It's enough to add just one operator to pybind
6162
file(APPEND ${pybind_file} "USE_OP(pool2d);\n")
6263
endif()
6364

65+
# pool_with_index_op contains several operators
66+
if ("${TARGET}" STREQUAL "pool_with_index_op")
67+
set(pybind_flag 1)
68+
# It's enough to add just one operator to pybind
69+
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
70+
endif()
71+
6472
# activation_op contains several operators
6573
if ("${TARGET}" STREQUAL "activation_op")
6674
set(pybind_flag 1)
@@ -75,13 +83,6 @@ function(op_library TARGET)
7583
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
7684
endif()
7785

78-
# pool_with_index_op contains several operators
79-
if ("${TARGET}" STREQUAL "pool_with_index_op")
80-
set(pybind_flag 1)
81-
# It's enough to just adding one operator to pybind
82-
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
83-
endif()
84-
8586
# pybind USE_NO_KERNEL_OP
8687
file(READ ${TARGET}.cc TARGET_CONTENT)
8788
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")

paddle/operators/math/pooling.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ namespace math {
2424

2525
#define FLT_MAX \
2626
__FLT_MAX__ // It might need to be placed in another file, but I'm still
27-
// wondering where to put it
27+
// wondering where to put it.
2828

2929
/*
3030
* \brief Extracting simple operations from pooling.
31-
* Both MaxPool and AvgPool need initial, compute and finalize operation.
31+
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
32+
* operations.
3233
* MaxPool initializes temp variable to the negative maximum to find the
3334
* maximum value in the pooling field.
3435
* AvgPool initializes temp variable to zero to accumulate all values
35-
* in pool pooling, and takes the average.
36+
* in the pooling field, and finally takes the average.
3637
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
3738
*/
3839
template <class T>
@@ -72,17 +73,17 @@ class AvgPoolGrad {
7273
/*
7374
* \brief Getting pooling results, and calculating gradient.
7475
*
75-
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
76-
* NCDHW format.
76+
* In pool2d, all tensors are in NCHW format, where N is batch size, C is the
77+
* number of channels, H and W are the height and width of feature.
78+
* In pool3d, all tensors are in NCDHW format, where N is batch size, C is the
79+
* number of channels, D, H and W are the depth, height and width of feature.
7780
*
7881
* In max pooling, it is possible that the pooling region has multiple maximum
79-
* elements.
80-
* In this case, we should compute the gradient of the first maximum element.
82+
* elements. In this case, we should compute the gradient of the first maximum
83+
* element.
8184
* This is different from average pooling. So we rewrite the max_pool_grad:
8285
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
83-
*
8486
*/
85-
8687
template <typename Place, typename PoolProcess, typename T>
8788
class Pool2dFunctor {
8889
public:
@@ -146,10 +147,9 @@ class MaxPool3dGradFunctor {
146147
/*
147148
* \brief Getting max pooling results and corresponding max index, and
148149
* calculating gradient.
149-
* In sub-sampling-pooling, it is necessary to know max element index.
150+
* In up-sampling pooling, it is necessary to know the max element index.
150151
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
151152
* NCDHW format.
152-
*
153153
*/
154154
template <typename Place, typename T>
155155
class MaxPool2dWithIndexFunctor {
@@ -188,6 +188,7 @@ class MaxPool3dWithIndexGradFunctor {
188188
const framework::Tensor& mask, std::vector<int>& ksize,
189189
std::vector<int>& strides, std::vector<int>& paddings);
190190
};
191+
191192
} // namespace math
192193
} // namespace operators
193194
} // namespace paddle

paddle/operators/pool_with_index_op.cc

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
3434
PADDLE_ENFORCE(ctx->HasOutput("Out"),
3535
"Out(Output) of Pooling should not be null.");
3636
PADDLE_ENFORCE(ctx->HasOutput("Mask"),
37-
"Out(Output) of Pooling should not be null.");
37+
"Mask(Output) of Pooling should not be null.");
3838

3939
auto in_x_dims = ctx->GetInputDim("X");
4040

@@ -52,13 +52,11 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
5252
}
5353

5454
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
55-
"Pooling intput size and pooling size should be consistent");
56-
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
57-
"Pooling size size should be 2 elements. or 3 elements.");
55+
"Intput size and pooling size should be consistent.");
5856
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
59-
"strides size and pooling size should be the same.");
57+
"Strides size and pooling size should be the same.");
6058
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
61-
"paddings size and pooling size should be the same.");
59+
"Paddings size and pooling size should be the same.");
6260

6361
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
6462
for (size_t i = 0; i < ksize.size(); ++i) {
@@ -76,11 +74,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
7674

7775
protected:
7876
void InferShape(framework::InferShapeContextBase *ctx) const override {
79-
PADDLE_ENFORCE(ctx->HasInput("X"),
80-
"X(Input) of Pooling should not be null.");
81-
PADDLE_ENFORCE(
82-
ctx->HasOutput(framework::GradVarName("X")),
83-
"X@GRAD(Input@GRAD) of MaxPoolWithIndexOpGrad should not be null.");
77+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
78+
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
79+
"Input(X@GRAD) should not be null.");
8480
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
8581
}
8682
};
@@ -110,9 +106,10 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
110106

111107
AddAttr<std::vector<int>>(
112108
"ksize",
113-
"Pooling size(height, width) of pooling operator."
109+
"The pooling size(height, width) of pooling operator."
114110
"If globalPooling = true, ksize is ignored and need not be "
115-
"specified."); // TODO(Add checker)
111+
"specified."); // TODO(Chengduo): Add checker. (Currently,
112+
// TypedAttrChecker doesn't support vector type.)
116113
AddAttr<bool>(
117114
"globalPooling",
118115
"Whether to use the globalPooling."
@@ -123,15 +120,21 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
123120
AddAttr<std::vector<int>>("strides",
124121
"Strides(height, width) of pooling operator."
125122
"Default {1,1}.")
126-
.SetDefault({1, 1}); // TODO(Add checker)
123+
.SetDefault({1, 1}); // TODO(Chengduo): Add checker. (Currently,
124+
// TypedAttrChecker doesn't support vector type.)
127125
AddAttr<std::vector<int>>("paddings",
128126
"Paddings(height, width) of pooling operator."
129127
"Default {0,0}.")
130-
.SetDefault({0, 0}); // TODO(Add checker)
128+
.SetDefault({0, 0}); // TODO(Chengduo): Add checker. (Currently,
129+
// TypedAttrChecker doesn't support vector type.)
131130

132131
AddComment(R"DOC(
133-
The maxPooling2d with index operation calculates the output and the mask based on
134-
the input and ksize, strides, paddings parameters.
132+
The maxPooling2d with index operation calculates the output and the mask
133+
based on the input and ksize, strides, paddings parameters. Input(X) and
134+
output(Out, Mask) are in NCHW format. Where N is batch size, C is the
135+
number of channels, H and W is the height and width of feature.
136+
Parameters(ksize, strides, paddings) are two elements.
137+
These two elements represent height and width, respectively.
135138
)DOC");
136139
}
137140
};
@@ -162,9 +165,10 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
162165

163166
AddAttr<std::vector<int>>(
164167
"ksize",
165-
"Pooling size(depth, height, width) of pooling operator."
168+
"The pooling size(depth, height, width) of pooling operator."
166169
"If globalPooling = true, ksize is ignored and need not be "
167-
"specified."); // TODO(Add checker)
170+
"specified."); // TODO(Chengduo): Add checker. (Currently,
171+
// TypedAttrChecker doesn't support vector type.)
168172
AddAttr<bool>(
169173
"globalPooling",
170174
"Whether to use the globalPooling."
@@ -176,19 +180,26 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
176180
"strides",
177181
"Strides(depth, height, width) of pooling operator."
178182
"Default {1,1,1}.")
179-
.SetDefault({1, 1, 1}); // TODO(Add checker)
183+
.SetDefault({1, 1, 1}); // TODO(Chengduo): Add checker. (Currently,
184+
// TypedAttrChecker doesn't support vector type.)
180185
AddAttr<std::vector<int>>(
181186
"paddings",
182187
"Paddings(depth, height, width) of pooling operator."
183188
"Default {0,0,0}.")
184-
.SetDefault({0, 0, 0}); // TODO(Add checker)
189+
.SetDefault({0, 0, 0}); // TODO(Chengduo): Add checker. (Currently,
190+
// TypedAttrChecker doesn't support vector type.)
185191

186192
AddComment(R"DOC(
187-
The maxpooling3d with index operation calculates the output and the mask based on
188-
the input and ksize, strides, paddings parameters.
193+
The maxpooling3d with index operation calculates the output and the mask
194+
based on the input and ksize, strides, paddings parameters.
195+
Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch
196+
size, C is the number of channels, D, H and W is the depth, height and
197+
width of feature. Parameters(ksize, strides, paddings) are three elements.
198+
These three elements represent depth, height and width, respectively.
189199
)DOC");
190200
}
191201
};
202+
192203
} // namespace operators
193204
} // namespace paddle
194205

0 commit comments

Comments
 (0)