Commit
[Op] global pool option (#2243)
antinucleon committed May 27, 2016
1 parent 118b37e commit 179ca3a
Showing 2 changed files with 61 additions and 22 deletions.
28 changes: 28 additions & 0 deletions src/operator/cudnn_pooling-inl.h
@@ -57,6 +57,9 @@ class CuDNNPoolingOp : public Operator {
if (!init_cudnn_) {
this->Init(s, in_data, out_data);
}
if (param_.global_pool) {
this->InitGlobalPool(data.shape_);
}
float alpha = 1.0f;
float beta = 0.0f;
CHECK_EQ(data.CheckContiguous(), true);
@@ -109,6 +112,31 @@ class CuDNNPoolingOp : public Operator {
}

private:
inline void InitGlobalPool(const mshadow::Shape<4> &dshape) {
#if CUDNN_MAJOR == 5
CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_,
mode_,
nan_prop_,
param_.global_pool ? dshape[2] : param_.kernel[0],
param_.global_pool ? dshape[3] : param_.kernel[1],
param_.pad[0],
param_.pad[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]),
CUDNN_STATUS_SUCCESS);
#else
CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_,
mode_,
param_.global_pool ? dshape[2] : param_.kernel[0],
param_.global_pool ? dshape[3] : param_.kernel[1],
param_.pad[0],
param_.pad[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]),
CUDNN_STATUS_SUCCESS);
#endif
}

inline void Init(mshadow::Stream<gpu> *s,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data) {
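When global_pool is set, InitGlobalPool configures the cuDNN 2D pooling descriptor with a window that spans the whole feature map and a stride of 1, so the spatial output is always 1x1. Below is a minimal standalone sketch of that argument selection; the EffectivePooling struct and Resolve helper are illustrative names, not part of the commit.

#include <cstdio>

// Hypothetical helper mirroring the argument selection in InitGlobalPool:
// with global_pool the pooling window covers the whole H x W feature map
// and the stride collapses to 1; otherwise the configured kernel and
// stride are used unchanged.
struct EffectivePooling {
  int kernel_h, kernel_w, stride_h, stride_w;
};

EffectivePooling Resolve(bool global_pool, int H, int W,
                         int kernel_h, int kernel_w,
                         int stride_h, int stride_w) {
  if (global_pool) {
    return {H, W, 1, 1};
  }
  return {kernel_h, kernel_w, stride_h, stride_w};
}

int main() {
  // A 7x7 feature map pooled globally: one 7x7 window, stride 1.
  EffectivePooling p = Resolve(true, 7, 7, /*kernel*/ 3, 3, /*stride*/ 2, 2);
  std::printf("kernel=%dx%d stride=%dx%d\n",
              p.kernel_h, p.kernel_w, p.stride_h, p.stride_w);
  return 0;
}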
55 changes: 33 additions & 22 deletions src/operator/pooling-inl.h
@@ -32,8 +32,12 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
TShape stride;
TShape pad;
int pool_type;
bool global_pool;
DMLC_DECLARE_PARAMETER(PoolingParam) {
// TODO(bing) change to only set lower bound
DMLC_DECLARE_FIELD(global_pool).set_default(false)
.describe("Ignore kernel size, do global pooling based on the current input feature map. "
"This is useful for inputs with different shapes");

DMLC_DECLARE_FIELD(kernel)
.set_expect_ndim(2).enforce_nonzero()
.describe("pooling kernel size: (y, x)");
@@ -81,20 +85,22 @@ class PoolingOp : public Operator {
req[pool_enum::kOut],
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.stride[0],
param_.stride[1]));
param_.global_pool ? data.shape_[2] : param_.kernel[0],
param_.global_pool ? data.shape_[3] : param_.kernel[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]));
} else if (param_.pool_type == pool_enum::kAvgPooling) {
Assign(out,
req[pool_enum::kOut],
(1.0f / (param_.kernel[0] * param_.kernel[1])) * \
(1.0f / (param_.global_pool ?
data.shape_[2] * data.shape_[3] :
param_.kernel[0] * param_.kernel[1])) * \
pool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
out_shape,
param_.kernel[0],
param_.kernel[1],
param_.stride[0],
param_.stride[1]));
param_.global_pool ? data.shape_[2] : param_.kernel[0],
param_.global_pool ? data.shape_[3] : param_.kernel[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]));
}
}
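For average pooling, the normalization above changes under global_pool: the divisor becomes the full H*W of the input feature map instead of the kernel area, so the single output value is just the mean of the map. A minimal standalone illustration of that arithmetic (not MXNet code):

#include <cstdio>

int main() {
  // Global average pooling over a 2x3 feature map: the scale is
  // 1 / (H * W) = 1/6 applied to the sum of all elements, i.e. the mean.
  const int H = 2, W = 3;
  float fmap[H][W] = {{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}};
  float sum = 0.f;
  for (int i = 0; i < H; ++i)
    for (int j = 0; j < W; ++j)
      sum += fmap[i][j];
  float out = (1.0f / (H * W)) * sum;  // == 3.5f
  std::printf("global average pool output: %f\n", out);
  return 0;
}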

@@ -126,10 +132,10 @@ class PoolingOp : public Operator {
crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
pad(output_data, 0, 0),
pad(grad, 0, 0),
param_.kernel[0],
param_.kernel[1],
param_.stride[0],
param_.stride[1]),
param_.global_pool ? in_shape[0] : param_.kernel[0],
param_.global_pool ? in_shape[1] : param_.kernel[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]),
in_shape,
param_.pad[0],
param_.pad[1]));
@@ -139,10 +145,10 @@
crop(unpool<Reducer>(pad(data, param_.pad[0], param_.pad[1]),
pad(output_data, 0, 0),
pad(grad, 0, 0),
param_.kernel[0],
param_.kernel[1],
param_.stride[0],
param_.stride[1]),
param_.global_pool ? in_shape[0] : param_.kernel[0],
param_.global_pool ? in_shape[1] : param_.kernel[1],
param_.global_pool ? 1 : param_.stride[0],
param_.global_pool ? 1 : param_.stride[1]),
in_shape,
param_.pad[0],
param_.pad[1]));
@@ -177,10 +183,15 @@ class PoolingProp : public OperatorProperty {
"Pooling: Input data should be 4D in (batch, channel, y, x)";
TShape oshape = dshape;
if (dshape.ndim() == 0) return false;
oshape[2] = std::min(dshape[2] + 2 * param_.pad[0] - param_.kernel[0] + param_.stride[0] - 1,
dshape[2] + 2 * param_.pad[0] - 1) / param_.stride[0] + 1;
oshape[3] = std::min(dshape[3] + 2 * param_.pad[1] - param_.kernel[1] + param_.stride[1] - 1,
dshape[3] + 2 * param_.pad[1] - 1) / param_.stride[1] + 1;
if (param_.global_pool) {
oshape[2] = 1;
oshape[3] = 1;
} else {
oshape[2] = std::min(dshape[2] + 2 * param_.pad[0] - param_.kernel[0] + param_.stride[0] - 1,
dshape[2] + 2 * param_.pad[0] - 1) / param_.stride[0] + 1;
oshape[3] = std::min(dshape[3] + 2 * param_.pad[1] - param_.kernel[1] + param_.stride[1] - 1,
dshape[3] + 2 * param_.pad[1] - 1) / param_.stride[1] + 1;
}
CHECK(oshape[2] > 0 && oshape[3] > 0) << "Pooling: kernel size exceed input";
out_shape->clear();
out_shape->push_back(oshape);
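The shape inference now short-circuits for global_pool: the spatial output is fixed at 1x1, while the non-global path keeps the existing min(...)/stride + 1 formula. A small standalone sketch of the same computation; the PooledExtent helper is illustrative only and not part of the commit.

#include <algorithm>
#include <cstdio>

// Illustrative helper mirroring InferShape above: global pooling always
// yields a spatial extent of 1; otherwise the pooled extent follows the
// min(...) / stride + 1 formula.
int PooledExtent(bool global_pool, int d, int pad, int kernel, int stride) {
  if (global_pool) return 1;
  return std::min(d + 2 * pad - kernel + stride - 1,
                  d + 2 * pad - 1) / stride + 1;
}

int main() {
  // 224x224 input, 2x2 kernel, stride 2, no padding -> 112x112 output.
  std::printf("%d\n", PooledExtent(false, 224, 0, 2, 2));
  // Same input with global_pool -> 1x1 regardless of kernel and stride.
  std::printf("%d\n", PooledExtent(true, 224, 0, 2, 2));
  return 0;
}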
