Skip to content

Enhance AvgPooling to support both include_mode and exclude_mode #6100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions paddle/cuda/include/hl_cnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ extern void hl_maxpool_backward(const int frameCnt,
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
 * @param[in] excludeMode whether to exclude padding from the pooling
 *                        window size: true divides by the un-padded
 *                        window area, false divides by sizeY * sizeX.
*
*/
extern void hl_avgpool_forward(const int frameCnt,
Expand All @@ -132,7 +133,8 @@ extern void hl_avgpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride);
const int tgtStride,
bool excludeMode);

/**
* @brief Maximum pool backward.
Expand All @@ -154,6 +156,7 @@ extern void hl_avgpool_forward(const int frameCnt,
* @param[in] scaleB scale.
* @param[out] backGrad output grad.
* @param[in] outStride stride between output data samples.
 * @param[in] excludeMode whether to exclude padding from the pooling
 *                        window size: true divides by the un-padded
 *                        window area, false divides by sizeY * sizeX.
*
*/
extern void hl_avgpool_backward(const int frameCnt,
Expand All @@ -172,7 +175,8 @@ extern void hl_avgpool_backward(const int frameCnt,
real scaleA,
real scaleB,
real* backGrad,
const int outStride);
const int outStride,
bool excludeMode);

extern void hl_maxpool3D_forward(const int frameCnt,
const real* inputData,
Expand Down
6 changes: 4 additions & 2 deletions paddle/cuda/include/stub/hl_cnn_stub.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ inline void hl_avgpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {}
const int tgtStride,
const bool excludeMode) {}

inline void hl_avgpool_backward(const int frameCnt,
const real* outGrad,
Expand All @@ -86,7 +87,8 @@ inline void hl_avgpool_backward(const int frameCnt,
real scaleA,
real scaleB,
real* backGrad,
const int outStride) {}
const int outStride,
const bool excludeMode) {}

inline void hl_maxpool3D_forward(const int frameCnt,
const real* inputData,
Expand Down
28 changes: 18 additions & 10 deletions paddle/cuda/src/hl_cuda_cnn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
const int padH,
const int padW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
const bool excludeMode) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int pw = index % pooledW;
Expand All @@ -224,7 +225,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
int wend = min(wstart + sizeX, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = (hend - hstart) * (wend - wstart);
int poolSize =
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;

real aveval = 0;
inputData += (frameNum * channels + c) * height * width;
Expand All @@ -235,7 +237,7 @@ __global__ void KeAvgPoolForward(const int nthreads,
}
int tgtIndex =
index % (pooledW * pooledH * channels) + frameNum * tgtStride;
tgtData[tgtIndex] = aveval / pool_size;
tgtData[tgtIndex] = aveval / poolSize;
}
}

Expand All @@ -253,7 +255,8 @@ void hl_avgpool_forward(const int frameCnt,
const int paddingH,
const int paddingW,
real* tgtData,
const int tgtStride) {
const int tgtStride,
const bool excludeMode) {
int num_kernels = pooledH * pooledW * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;
KeAvgPoolForward<<<blocks, 1024, 0, STREAM_DEFAULT>>>(num_kernels,
Expand All @@ -270,7 +273,8 @@ void hl_avgpool_forward(const int frameCnt,
paddingH,
paddingW,
tgtData,
tgtStride);
tgtStride,
excludeMode);
CHECK_SYNC("hl_avgpool_forward failed");
}

Expand All @@ -290,7 +294,8 @@ __global__ void KeAvgPoolBackward(const int nthreads,
real scaleA,
real scaleB,
real* tgtGrad,
const int outStride) {
const int outStride,
const bool excludeMode) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < nthreads) {
int offsetW = index % width + padW;
Expand All @@ -314,8 +319,9 @@ __global__ void KeAvgPoolBackward(const int nthreads,
int wstart = pw * strideW - padW;
int wend = min(wstart + sizeX, width);
wstart = max(wstart, 0);
int poolsize = (hend - hstart) * (wend - wstart);
gradient += outGrad[ph * pooledW + pw] / poolsize;
int poolSize =
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
gradient += outGrad[ph * pooledW + pw] / poolSize;
}
}
tgtGrad[index] = scaleB * tgtGrad[index] + scaleA * gradient;
Expand All @@ -338,7 +344,8 @@ void hl_avgpool_backward(const int frameCnt,
real scaleA,
real scaleB,
real* backGrad,
const int outStride) {
const int outStride,
const bool excludeMode) {
int num_kernels = height * width * channels * frameCnt;
int blocks = (num_kernels + 1024 - 1) / 1024;

Expand All @@ -358,7 +365,8 @@ void hl_avgpool_backward(const int frameCnt,
scaleA,
scaleB,
backGrad,
outStride);
outStride,
excludeMode);
CHECK_SYNC("hl_avgpool_backward failed");
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/gserver/layers/PoolLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();

excludeMode_ = conf.has_exclude_mode() ? conf.exclude_mode() : true;
return true;
}

Expand Down
2 changes: 2 additions & 0 deletions paddle/gserver/layers/PoolLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class PoolLayer : public Layer {

std::string poolType_;

bool excludeMode_;

public:
explicit PoolLayer(const LayerConfig& config) : Layer(config) {}

Expand Down
8 changes: 6 additions & 2 deletions paddle/gserver/layers/PoolProjection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
strideY_ = conf.has_stride_y() ? conf.stride_y() : conf.stride();
confPaddingY_ = conf.has_padding_y() ? conf.padding_y() : conf.padding();
outputY_ = conf.has_output_y() ? conf.output_y() : conf.output_x();

excludeMode_ = conf.has_exclude_mode() ? conf.exclude_mode() : true;
}

size_t PoolProjection::getSize() {
Expand Down Expand Up @@ -141,7 +143,8 @@ void AvgPoolProjection::forward() {
outputY_,
outputX_,
confPaddingY_,
confPadding_);
confPadding_,
excludeMode_);
}

void AvgPoolProjection::backward(const UpdateCallback& callback) {
Expand All @@ -166,6 +169,7 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
1,
1,
confPaddingY_,
confPadding_);
confPadding_,
excludeMode_);
}
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/gserver/layers/PoolProjection.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class PoolProjection : public Projection {
int confPaddingY_, confPadding_;
size_t channels_;
std::string poolType_;
bool excludeMode_;

public:
PoolProjection(const ProjectionConfig& config,
Expand Down
16 changes: 15 additions & 1 deletion paddle/gserver/tests/test_LayerGrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1211,14 +1211,18 @@ void setPoolConfig(TestConfig* config,
pool->set_output_y(oh);
}

void testPoolLayer(const string& poolType, bool trans, bool useGpu) {
void testPoolLayer(const string& poolType,
bool trans,
bool useGpu,
bool excludeMode = true) {
TestConfig config;
config.inputDefs.push_back({INPUT_DATA, "layer_0", 3136, 0});
LayerInputConfig* input = config.layerConfig.add_inputs();
PoolConfig* pool = input->mutable_pool_conf();

pool->set_img_size(14);
pool->set_img_size_y(14);
pool->set_exclude_mode(excludeMode);
setPoolConfig(&config, pool, poolType);
config.layerConfig.set_size(pool->output_x() * pool->output_y() *
pool->channels());
Expand Down Expand Up @@ -1250,16 +1254,26 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {

TEST(Layer, PoolLayer) {
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("avg-projection",
/* trans= */ false,
/* useGpu= */ false,
/* excludeMode= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ false);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ false);

#ifdef PADDLE_WITH_CUDA
testPoolLayer("avg-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("avg-projection",
/* trans= */ false,
/* useGpu= */ true,
/* excludeMode= */ false);
testPoolLayer("max-projection", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-max-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2("cudnn-avg-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer2(
"cudnn-avg-incl-pad-pool", /* trans= */ false, /* useGpu= */ true);
testPoolLayer("max-pool-with-mask", /* trans= */ false, /* useGpu= */ true);
#endif
}
Expand Down
24 changes: 16 additions & 8 deletions paddle/math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode) {
CHECK(inputMat.useGpu_ == true) << "Matrix type are not equal";

real* inputData = inputMat.getData();
Expand All @@ -1153,7 +1154,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
paddingH,
paddingW,
data_,
getStride());
getStride(),
excludeMode);
}

void GpuMatrix::avgPoolBackward(Matrix& outGrad,
Expand All @@ -1168,7 +1170,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode) {
CHECK(outGrad.useGpu_ == true) << "Matrix type are not equal";

real* outDiff = outGrad.getData();
Expand All @@ -1194,7 +1197,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
scaleTargets,
scaleOutput,
data_,
outGrad.getStride());
outGrad.getStride(),
excludeMode);
}

void GpuMatrix::maxPool3DForward(Matrix& inputMat,
Expand Down Expand Up @@ -2136,7 +2140,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode) {
// The main loop
size_t num = input.getHeight();
size_t inLength = imgSizeH * imgSizeW;
Expand Down Expand Up @@ -2165,7 +2170,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
tgtData[ph * outputW + pw] += inData[h * imgSizeW + w];
}
}
int poolSize = (hend - hstart) * (wend - wstart);
int poolSize =
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
CHECK(poolSize);
tgtData[ph * outputW + pw] /= poolSize;
}
Expand All @@ -2189,7 +2195,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode) {
size_t num = input.getHeight();
size_t channels = input.getWidth() / outputH / outputW;
size_t inLength = imgSizeH * imgSizeW;
Expand All @@ -2211,7 +2218,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
int wstart = pw * strideW - paddingW;
int wend = std::min(wstart + sizeX, imgSizeW);
wstart = std::max(wstart, 0);
int poolSize = (hend - hstart) * (wend - wstart);
int poolSize =
excludeMode ? (hend - hstart) * (wend - wstart) : sizeY * sizeX;
CHECK(poolSize);

for (int h = hstart; h < hend; ++h) {
Expand Down
19 changes: 13 additions & 6 deletions paddle/math/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,7 +911,8 @@ class Matrix : public BaseMatrix {
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode = true) {
LOG(FATAL) << "Not implemeted";
}

Expand All @@ -927,9 +928,11 @@ class Matrix : public BaseMatrix {
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW) {
size_t paddingW,
bool excludeMode = true) {
LOG(FATAL) << "Not implemeted";
}

/**
* Pooling 3D forward operation, pick out the largest element
* in the sizeX of value
Expand Down Expand Up @@ -1458,7 +1461,8 @@ class GpuMatrix : public Matrix {
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
size_t paddingW,
bool excludeMode = true);

void avgPoolBackward(Matrix& input,
size_t imgSizeH,
Expand All @@ -1472,7 +1476,8 @@ class GpuMatrix : public Matrix {
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
size_t paddingW,
bool excludeMode = true);

void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
Expand Down Expand Up @@ -1730,7 +1735,8 @@ class CpuMatrix : public Matrix {
size_t outputH,
size_t outputW,
size_t paddingH,
size_t paddingW);
size_t paddingW,
bool excludeMode = true);

void avgPoolBackward(Matrix& input,
size_t imgSizeH,
Expand All @@ -1744,7 +1750,8 @@ class CpuMatrix : public Matrix {
real scaleTargets,
real scaleOutput,
size_t paddingH,
size_t paddingW);
size_t paddingW,
bool excludeMode = true);

void maxPool3DForward(Matrix& inputMat,
Matrix& maxPoolIdx,
Expand Down
2 changes: 2 additions & 0 deletions proto/ModelConfig.proto
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ message PoolConfig {
optional uint32 output_z = 16 [ default = 1 ];
optional uint32 img_size_z = 17 [ default = 1 ];
optional uint32 padding_z = 18 [ default = 1 ];

optional bool exclude_mode = 19;
}

message SppConfig {
Expand Down
Loading