diff --git a/src/operator/cudnn_pooling-inl.h b/src/operator/cudnn_pooling-inl.h
index 543dc5cc..c7fa214a 100644
--- a/src/operator/cudnn_pooling-inl.h
+++ b/src/operator/cudnn_pooling-inl.h
@@ -57,6 +57,9 @@ class CuDNNPoolingOp : public Operator {
     if (!init_cudnn_) {
       this->Init(s, in_data, out_data);
     }
+    if (param_.global_pool) {  // the window depends on the input shape, so refresh the descriptor from data.shape_
+      this->InitGlobalPool(data.shape_);
+    }
     float alpha = 1.0f;
     float beta = 0.0f;
     CHECK_EQ(data.CheckContiguous(), true);
@@ -109,6 +112,31 @@ class CuDNNPoolingOp : public Operator {
   }
 
  private:
+  inline void InitGlobalPool(const mshadow::Shape<4> &dshape) {  // set pooling_desc_ to one dshape[2] x dshape[3] window
+    #if CUDNN_MAJOR == 5  // this branch passes nan_prop_ as an extra descriptor argument
+    CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_,
+                                         mode_,
+                                         nan_prop_,
+                                         param_.global_pool ? dshape[2] : param_.kernel[0],
+                                         param_.global_pool ? dshape[3] : param_.kernel[1],
+                                         param_.global_pool ? 0 : param_.pad[0],  // no padding: padded full-size window gives 2*pad+1 outputs, not 1
+                                         param_.global_pool ? 0 : param_.pad[1],
+                                         param_.global_pool ? 1 : param_.stride[0],
+                                         param_.global_pool ? 1 : param_.stride[1]),
+             CUDNN_STATUS_SUCCESS);
+    #else  // same descriptor setup without the nan_prop_ argument
+    CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_,
+                                         mode_,
+                                         param_.global_pool ? dshape[2] : param_.kernel[0],
+                                         param_.global_pool ? dshape[3] : param_.kernel[1],
+                                         param_.global_pool ? 0 : param_.pad[0],  // no padding, same reason as the branch above
+                                         param_.global_pool ? 0 : param_.pad[1],
+                                         param_.global_pool ? 1 : param_.stride[0],
+                                         param_.global_pool ? 1 : param_.stride[1]),
+             CUDNN_STATUS_SUCCESS);
+    #endif  // CUDNN_MAJOR == 5
+  }
+
   inline void Init(mshadow::Stream *s,
                    const std::vector &in_data,
                    const std::vector &out_data) {
diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h
index 51c7125f..ac6190d3 100644
--- a/src/operator/pooling-inl.h
+++ b/src/operator/pooling-inl.h
@@ -32,8 +32,12 @@ struct PoolingParam : public dmlc::Parameter {
   TShape stride;
   TShape pad;
   int pool_type;
+  bool global_pool;  // when true, the operators below replace kernel with the input's spatial size and stride with 1
   DMLC_DECLARE_PARAMETER(PoolingParam) {
-    // TODO(bing) change to only set lower bound
+    DMLC_DECLARE_FIELD(global_pool).set_default(false)
+    .describe("Ignore kernel size, do global pooling based on current input feature map. "
+              "This is useful for input with different shape");
+
     DMLC_DECLARE_FIELD(kernel)
     .set_expect_ndim(2).enforce_nonzero()
     .describe("pooling kernel size: (y, x)");
@@ -81,20 +85,22 @@ class PoolingOp : public Operator {
              req[pool_enum::kOut],
              pool(pad(data, param_.pad[0], param_.pad[1]),
                   out_shape,
-                  param_.kernel[0],
-                  param_.kernel[1],
-                  param_.stride[0],
-                  param_.stride[1]));
+                  param_.global_pool ? data.shape_[2] : param_.kernel[0],  // NOTE(review): pad(...) above still applies param_.pad under global_pool - confirm pad is expected to be 0
+                  param_.global_pool ? data.shape_[3] : param_.kernel[1],
+                  param_.global_pool ? 1 : param_.stride[0],
+                  param_.global_pool ? 1 : param_.stride[1]));
     } else if (param_.pool_type == pool_enum::kAvgPooling) {
       Assign(out,
              req[pool_enum::kOut],
-             (1.0f / (param_.kernel[0] * param_.kernel[1])) * \
+             (1.0f / (param_.global_pool ?  // average over the whole map when pooling globally
+                      data.shape_[2] * data.shape_[3] :
+                      param_.kernel[0] * param_.kernel[1])) * \
              pool(pad(data, param_.pad[0], param_.pad[1]),
                   out_shape,
-                  param_.kernel[0],
-                  param_.kernel[1],
-                  param_.stride[0],
-                  param_.stride[1]));
+                  param_.global_pool ? data.shape_[2] : param_.kernel[0],
+                  param_.global_pool ? data.shape_[3] : param_.kernel[1],
+                  param_.global_pool ? 1 : param_.stride[0],
+                  param_.global_pool ? 1 : param_.stride[1]));
     }
   }
 
@@ -126,10 +132,10 @@ class PoolingOp : public Operator {
              crop(unpool(pad(data, param_.pad[0], param_.pad[1]),
                          pad(output_data, 0, 0),
                          pad(grad, 0, 0),
-                         param_.kernel[0],
-                         param_.kernel[1],
-                         param_.stride[0],
-                         param_.stride[1]),
+                         param_.global_pool ? in_shape[0] : param_.kernel[0],  // assumes in_shape is the (y, x) size of the input - TODO confirm
+                         param_.global_pool ? in_shape[1] : param_.kernel[1],
+                         param_.global_pool ? 1 : param_.stride[0],
+                         param_.global_pool ? 1 : param_.stride[1]),
                   in_shape,
                   param_.pad[0],
                   param_.pad[1]));
@@ -139,10 +145,10 @@ class PoolingOp : public Operator {
              crop(unpool(pad(data, param_.pad[0], param_.pad[1]),
                          pad(output_data, 0, 0),
                          pad(grad, 0, 0),
-                         param_.kernel[0],
-                         param_.kernel[1],
-                         param_.stride[0],
-                         param_.stride[1]),
+                         param_.global_pool ? in_shape[0] : param_.kernel[0],  // NOTE(review): any scale factor applied to this gradient lies outside this hunk - confirm it also honors global_pool
+                         param_.global_pool ? in_shape[1] : param_.kernel[1],
+                         param_.global_pool ? 1 : param_.stride[0],
+                         param_.global_pool ? 1 : param_.stride[1]),
                   in_shape,
                   param_.pad[0],
                   param_.pad[1]));
@@ -177,10 +183,15 @@ class PoolingProp : public OperatorProperty {
         "Pooling: Input data should be 4D in (batch, channel, y, x)";
     TShape oshape = dshape;
     if (dshape.ndim() == 0) return false;
-    oshape[2] = std::min(dshape[2] + 2 * param_.pad[0] - param_.kernel[0] + param_.stride[0] - 1,
-                         dshape[2] + 2 * param_.pad[0] - 1) / param_.stride[0] + 1;
-    oshape[3] = std::min(dshape[3] + 2 * param_.pad[1] - param_.kernel[1] + param_.stride[1] - 1,
-                         dshape[3] + 2 * param_.pad[1] - 1) / param_.stride[1] + 1;
+    if (param_.global_pool) {  // global pooling always produces a 1x1 output per channel
+      oshape[2] = 1;
+      oshape[3] = 1;
+    } else {  // otherwise derive the output size from kernel, stride and pad as before
+      oshape[2] = std::min(dshape[2] + 2 * param_.pad[0] - param_.kernel[0] + param_.stride[0] - 1,
+                           dshape[2] + 2 * param_.pad[0] - 1) / param_.stride[0] + 1;
+      oshape[3] = std::min(dshape[3] + 2 * param_.pad[1] - param_.kernel[1] + param_.stride[1] - 1,
+                           dshape[3] + 2 * param_.pad[1] - 1) / param_.stride[1] + 1;
+    }
     CHECK(oshape[2] > 0 && oshape[3] > 0) << "Pooling: kernel size exceed input";
     out_shape->clear();
     out_shape->push_back(oshape);