Skip to content

Commit 749065c

Browse files
committed
Merge pull request tensorflow#2137 from zheng-xq/conv_grad_stride_ksize
Issue 2066: Fix the conv for stride > ksize case.
2 parents 5e22e3a + 8ec95eb commit 749065c

File tree

2 files changed

+40
-12
lines changed

2 files changed

+40
-12
lines changed

tensorflow/core/kernels/conv_grad_ops.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#define USE_EIGEN_TENSOR
1919
#define EIGEN_USE_THREADS
2020

21+
#include <algorithm>
2122
#include <vector>
2223
#include "tensorflow/core/framework/numeric_op.h"
2324
#include "tensorflow/core/framework/op_kernel.h"
@@ -838,13 +839,13 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
838839
context->allocate_output(0, input_shape, &in_backprop));
839840

840841
const int padding_rows =
841-
(padding_ == VALID)
842-
? 0
843-
: (output_rows - 1) * stride_rows + filter_rows - input_rows;
842+
(padding_ == VALID) ? 0
843+
: std::max<int>(0, (output_rows - 1) * stride_rows +
844+
filter_rows - input_rows);
844845
const int padding_cols =
845-
(padding_ == VALID)
846-
? 0
847-
: (output_cols - 1) * stride_cols + filter_cols - input_cols;
846+
(padding_ == VALID) ? 0
847+
: std::max<int>(0, (output_cols - 1) * stride_cols +
848+
filter_cols - input_cols);
848849

849850
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
850851
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -1137,13 +1138,13 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
11371138
context->allocate_output(0, filter_shape, &filter_backprop));
11381139

11391140
const int padding_rows =
1140-
(padding_ == VALID)
1141-
? 0
1142-
: (output_rows - 1) * stride_rows + filter_rows - input_rows;
1141+
(padding_ == VALID) ? 0
1142+
: std::max<int>(0, (output_rows - 1) * stride_rows +
1143+
filter_rows - input_rows);
11431144
const int padding_cols =
1144-
(padding_ == VALID)
1145-
? 0
1146-
: (output_cols - 1) * stride_cols + filter_cols - input_cols;
1145+
(padding_ == VALID) ? 0
1146+
: std::max<int>(0, (output_cols - 1) * stride_cols +
1147+
filter_cols - input_cols);
11471148

11481149
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
11491150
// calling it when that is true. Remove this check when (if?) cuDNN starts

tensorflow/python/kernel_tests/conv_ops_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,21 @@ def testConv2D2x2Depth3ValidBackpropInputStride1x2(self):
465465
data_format=data_format,
466466
use_gpu=use_gpu)
467467

468+
def testConv2DStrideTwoFilterOneSameBackpropInput(self):
469+
expected_output = [1.0, 0.0, 2.0, 0.0,
470+
0.0, 0.0, 0.0, 0.0,
471+
3.0, 0.0, 4.0, 0.0,
472+
0.0, 0.0, 0.0, 0.0]
473+
for (data_format, use_gpu) in GetTestConfigs():
474+
self._RunAndVerifyBackpropInput(input_sizes=[1, 4, 4, 1],
475+
filter_sizes=[1, 1, 1, 1],
476+
output_sizes=[1, 2, 2, 1],
477+
strides=[2, 2],
478+
padding="SAME",
479+
expected=expected_output,
480+
data_format=data_format,
481+
use_gpu=use_gpu)
482+
468483
# Testing for backprops
469484
def _RunAndVerifyBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
470485
strides, padding, expected, data_format,
@@ -568,6 +583,18 @@ def testConv2D2x2Depth3ValidBackpropFilterStride1x2(self):
568583
data_format=data_format,
569584
use_gpu=use_gpu)
570585

586+
def testConv2DStrideTwoFilterOneSameBackpropFilter(self):
587+
expected_output = [78.]
588+
for (data_format, use_gpu) in GetTestConfigs():
589+
self._RunAndVerifyBackpropFilter(input_sizes=[1, 4, 4, 1],
590+
filter_sizes=[1, 1, 1, 1],
591+
output_sizes=[1, 2, 2, 1],
592+
strides=[2, 2],
593+
padding="SAME",
594+
expected=expected_output,
595+
data_format=data_format,
596+
use_gpu=use_gpu)
597+
571598
# Gradient checkers
572599
def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows,
573600
filter_cols, in_depth, out_depth, stride_rows,

0 commit comments

Comments
 (0)