Skip to content

Commit ce016c8

Browse files
Vijay Vasudevantensorflower-gardener
Vijay Vasudevan
authored andcommitted
Add initial support for NCHW format for DepthwiseConv2D (and separable_conv2d).
This is a straightforward port of the NHWC kernels with different input tensor indexing. This is a baseline for future optimizations, and allows running separable nets with NCHW format throughout. Change: 149565634
1 parent 96cb8f8 commit ce016c8

14 files changed

+832
-168
lines changed

tensorflow/core/framework/common_shape_fns.cc

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,26 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
332332
strides.size());
333333
}
334334

335+
string data_format;
336+
Status s = c->GetAttr("data_format", &data_format);
337+
int32 stride_rows;
338+
int32 stride_cols;
339+
if (s.ok() && data_format == "NCHW") {
340+
// Convert input shape to default NHWC for inference
341+
input_shape =
342+
c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
343+
c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
344+
stride_rows = strides[2];
345+
stride_cols = strides[3];
346+
} else {
347+
stride_rows = strides[1];
348+
stride_cols = strides[2];
349+
}
350+
335351
DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
336352
DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
337353
DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
354+
338355
DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
339356
DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
340357
DimensionHandle input_depth = c->Dim(filter_shape, 2);
@@ -350,9 +367,6 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
350367
Padding padding;
351368
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
352369

353-
const int32 stride_rows = strides[1];
354-
const int32 stride_cols = strides[2];
355-
356370
// TODO(mrry,shlens): Raise an error if the stride would cause
357371
// information in the input to be ignored. This will require a change
358372
// in the kernel implementation.
@@ -363,8 +377,14 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
363377
TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
364378
c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
365379

366-
ShapeHandle output_shape =
367-
c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
380+
ShapeHandle output_shape;
381+
if (data_format == "NCHW") {
382+
output_shape =
383+
c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols});
384+
} else {
385+
output_shape =
386+
c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});
387+
}
368388
c->set_output(0, output_shape);
369389
return Status::OK();
370390
}

tensorflow/core/framework/common_shape_fns_test.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,7 @@ TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
613613
.Input("filter", 0, DT_FLOAT)
614614
.Attr("strides", strides)
615615
.Attr("padding", "VALID")
616+
.Attr("data_format", "NHWC")
616617
.Finalize(&op.node_def));
617618

618619
// Most of DepthwiseConv2D is implicitly tested by Conv2D, so
@@ -634,6 +635,18 @@ TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
634635
INFER_OK(op, "[1,2,2,3];[1,1,?,4]", "[d0_0,2,2,12]");
635636
INFER_OK(op, "[1,2,2,?];[1,1,?,4]", "[d0_0,2,2,?]");
636637
INFER_OK(op, "[1,2,2,3];[1,1,3,?]", "[d0_0,2,2,?]");
638+
639+
// Test for NCHW format.
640+
TF_CHECK_OK(NodeDefBuilder("test", "DepthwiseConv2dNative")
641+
.Input("input", 0, DT_FLOAT)
642+
.Input("filter", 0, DT_FLOAT)
643+
.Attr("strides", strides)
644+
.Attr("padding", "VALID")
645+
.Attr("data_format", "NCHW")
646+
.Finalize(&op.node_def));
647+
648+
// 1x1 filter, depth multiplication
649+
INFER_OK(op, "[1,3,2,2];[1,1,3,4]", "[d0_0,12,2,2]");
637650
}
638651

639652
TEST(CommonShapeFnsTest, AvgPool2DShapeTest) {

tensorflow/core/kernels/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2434,6 +2434,7 @@ tf_kernel_library(
24342434
name = "depthwise_conv_op",
24352435
prefix = "depthwise_conv_op",
24362436
deps = [
2437+
":bounds_check",
24372438
":conv_ops",
24382439
":ops_util",
24392440
"//tensorflow/core:core_cpu",
@@ -2450,6 +2451,7 @@ tf_kernel_library(
24502451
],
24512452
prefix = "depthwise_conv_grad_op",
24522453
deps = [
2454+
":bounds_check",
24532455
":ops_util",
24542456
"//tensorflow/core:core_cpu",
24552457
"//tensorflow/core:framework",

tensorflow/core/kernels/depthwise_conv_grad_op.cc

Lines changed: 96 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,14 @@ limitations under the License.
2525
#include "tensorflow/core/framework/tensor_shape.h"
2626
#include "tensorflow/core/framework/tensor_types.h"
2727
#include "tensorflow/core/framework/types.h"
28+
#include "tensorflow/core/kernels/bounds_check.h"
2829
#include "tensorflow/core/kernels/depthwise_conv_op.h"
2930
#include "tensorflow/core/kernels/ops_util.h"
3031
#include "tensorflow/core/lib/core/status.h"
3132
#include "tensorflow/core/platform/logging.h"
3233
#include "tensorflow/core/platform/types.h"
3334
#include "tensorflow/core/util/padding.h"
35+
#include "tensorflow/core/util/tensor_format.h"
3436
#include "tensorflow/core/util/work_sharder.h"
3537

3638
#if GOOGLE_CUDA
@@ -62,23 +64,51 @@ typedef Eigen::GpuDevice GPUDevice;
6264
context, batch == out_backprop.dim_size(0), \
6365
errors::InvalidArgument( \
6466
label, ": input and out_backprop must have the same batch size")); \
65-
const int64 input_rows = input_shape.dim_size(1); \
66-
const int64 input_cols = input_shape.dim_size(2); \
67+
const int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H'); \
68+
OP_REQUIRES( \
69+
context, \
70+
FastBoundsCheck(input_rows_raw, std::numeric_limits<int32>::max()), \
71+
errors::InvalidArgument("Input rows too large")); \
72+
const int32 input_rows = static_cast<int32>(input_rows_raw); \
73+
const int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W'); \
74+
OP_REQUIRES( \
75+
context, \
76+
FastBoundsCheck(input_cols_raw, std::numeric_limits<int32>::max()), \
77+
errors::InvalidArgument("Input cols too large")); \
78+
const int32 input_cols = static_cast<int32>(input_cols_raw); \
6779
const int64 filter_rows = filter_shape.dim_size(0); \
6880
const int64 filter_cols = filter_shape.dim_size(1); \
69-
const int64 output_rows = out_backprop.dim_size(1); \
70-
const int64 output_cols = out_backprop.dim_size(2); \
71-
const int64 in_depth = input_shape.dim_size(3); \
81+
const int64 output_rows_raw = \
82+
GetTensorDim(out_backprop.shape(), data_format_, 'H'); \
83+
OP_REQUIRES( \
84+
context, \
85+
FastBoundsCheck(output_rows_raw, std::numeric_limits<int32>::max()), \
86+
errors::InvalidArgument("Output rows too large")); \
87+
const int32 output_rows = static_cast<int32>(output_rows_raw); \
88+
const int64 output_cols_raw = \
89+
GetTensorDim(out_backprop.shape(), data_format_, 'W'); \
90+
OP_REQUIRES( \
91+
context, \
92+
FastBoundsCheck(output_cols_raw, std::numeric_limits<int32>::max()), \
93+
errors::InvalidArgument("Output cols too large")); \
94+
const int32 output_cols = static_cast<int32>(output_cols_raw); \
95+
const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
7296
OP_REQUIRES(context, in_depth == filter_shape.dim_size(2), \
7397
errors::InvalidArgument( \
7498
label, ": input and filter must have the same in_depth")); \
7599
const int64 depth_multiplier = filter_shape.dim_size(3); \
76-
const int64 out_depth = out_backprop.dim_size(3); \
100+
const int64 out_depth_raw = \
101+
GetTensorDim(out_backprop.shape(), data_format_, 'C'); \
102+
OP_REQUIRES( \
103+
context, \
104+
FastBoundsCheck(out_depth_raw, std::numeric_limits<int32>::max()), \
105+
errors::InvalidArgument("Output depth too large")); \
106+
const int32 out_depth = static_cast<int32>(out_depth_raw); \
77107
OP_REQUIRES( \
78108
context, (depth_multiplier * in_depth) == out_depth, \
79109
errors::InvalidArgument( \
80110
label, ": depth_multiplier * in_depth not equal to out_depth")); \
81-
const auto stride = strides_[1]; \
111+
const auto stride = stride_; \
82112
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; \
83113
OP_REQUIRES_OK(context, \
84114
GetWindowedOutputSize(input_rows, filter_rows, stride, \
@@ -343,7 +373,12 @@ struct LaunchDepthwiseConvBackpropInputOp<CPUDevice, T> {
343373
344374
static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
345375
const T* out_backprop, const T* depthwise_filter,
346-
T* in_backprop) {
376+
T* in_backprop, TensorFormat data_format) {
377+
OP_REQUIRES(
378+
ctx, data_format == FORMAT_NHWC,
379+
errors::Unimplemented(
380+
"Depthwise convolution on CPU is only supported for NHWC format"));
381+
347382
static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
348383
349384
// Pad 'depthwise_filter' to vector register width (if needed).
@@ -482,16 +517,18 @@ static void DepthwiseConvBackpropInputReference(const DepthwiseArgs& args,
482517
template <typename T>
483518
struct DepthwiseConv2dBackpropInputGPULaunch {
484519
static void Run(const GPUDevice& d, const DepthwiseArgs args,
485-
const T* out_backprop, const T* filter, T* in_backprop);
520+
const T* out_backprop, const T* filter, T* in_backprop,
521+
TensorFormat data_format);
486522
};
487523
488524
template <typename T>
489525
struct LaunchDepthwiseConvBackpropInputOp<GPUDevice, T> {
490526
static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
491-
const T* out_backprop, const T* filter, T* in_backprop) {
527+
const T* out_backprop, const T* filter, T* in_backprop,
528+
TensorFormat data_format) {
492529
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
493-
DepthwiseConv2dBackpropInputGPULaunch<T>().Run(d, args, out_backprop,
494-
filter, in_backprop);
530+
DepthwiseConv2dBackpropInputGPULaunch<T>().Run(
531+
d, args, out_backprop, filter, in_backprop, data_format);
495532
auto stream = ctx->op_device_context()->stream();
496533
OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
497534
"DepthwiseConv2dBackpropInp"
@@ -511,12 +548,23 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
511548
OP_REQUIRES(context, strides_.size() == 4,
512549
errors::InvalidArgument("Sliding window strides field must "
513550
"specify 4 dimensions"));
514-
OP_REQUIRES(context, strides_[1] == strides_[2],
551+
552+
string data_format;
553+
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
554+
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
555+
errors::InvalidArgument("Invalid data format"));
556+
557+
stride_ = GetTensorDim(strides_, data_format_, 'H');
558+
const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
559+
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
560+
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
561+
562+
OP_REQUIRES(context, stride_ == stride_w,
515563
errors::InvalidArgument(
516564
"Current implementation only supports equal length "
517565
"strides in the row and column dimensions."));
518566
OP_REQUIRES(
519-
context, (strides_[0] == 1 && strides_[3] == 1),
567+
context, (stride_n == 1 && stride_c == 1),
520568
errors::InvalidArgument("Current implementation does not yet support "
521569
"strides in the batch and depth dimensions."));
522570
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
@@ -539,7 +587,6 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
539587
input_shape.AddDim(in_sizes_data[i]);
540588
}
541589
const TensorShape& filter_shape = filter.shape();
542-
543590
EXTRACT_AND_VERIFY_DIMENSIONS("DepthwiseConv2DBackpropInput");
544591
Tensor* in_backprop = nullptr;
545592
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
@@ -552,12 +599,15 @@ class DepthwiseConv2dNativeBackpropInputOp : public OpKernel {
552599
return;
553600
}
554601
LaunchDepthwiseConvBackpropInputOp<Device, T>::launch(
555-
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr);
602+
context, args, out_backprop_ptr, filter_ptr, in_backprop_ptr,
603+
data_format_);
556604
}
557605
558606
private:
559607
std::vector<int32> strides_;
560608
Padding padding_;
609+
TensorFormat data_format_;
610+
int64 stride_;
561611
562612
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropInputOp);
563613
};
@@ -695,8 +745,13 @@ struct LaunchDepthwiseConvBackpropFilterOp<CPUDevice, T> {
695745
typedef typename Eigen::internal::packet_traits<T>::type Packet;
696746
697747
static void launch(OpKernelContext* ctx, const DepthwiseArgs& args,
698-
const T* out_backprop, const T* input,
699-
T* filter_backprop) {
748+
const T* out_backprop, const T* input, T* filter_backprop,
749+
TensorFormat data_format) {
750+
OP_REQUIRES(
751+
ctx, data_format == FORMAT_NHWC,
752+
errors::Unimplemented(
753+
"Depthwise convolution on CPU is only supported for NHWC format"));
754+
700755
static const int64 kPacketSize = (sizeof(Packet) / sizeof(T));
701756
702757
const int64 filter_spatial_size = args.filter_rows * args.filter_cols;
@@ -855,14 +910,15 @@ static void DepthwiseConvBackpropFilterReference(const DepthwiseArgs& args,
855910
template <typename T>
856911
struct DepthwiseConv2dBackpropFilterGPULaunch {
857912
static void Run(const GPUDevice& d, const DepthwiseArgs args,
858-
const T* out_backprop, const T* input, T* filter_backprop);
913+
const T* out_backprop, const T* input, T* filter_backprop,
914+
TensorFormat data_format);
859915
};
860916
861917
template <typename T>
862918
struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T> {
863919
static void launch(OpKernelContext* ctx, const DepthwiseArgs args,
864-
const T* out_backprop, const T* input,
865-
T* filter_backprop) {
920+
const T* out_backprop, const T* input, T* filter_backprop,
921+
TensorFormat data_format) {
866922
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
867923
auto stream = ctx->op_device_context()->stream();
868924
@@ -873,8 +929,8 @@ struct LaunchDepthwiseConvBackpropFilterOp<GPUDevice, T> {
873929
num_filter_backprop);
874930
stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
875931
876-
DepthwiseConv2dBackpropFilterGPULaunch<T>().Run(d, args, out_backprop,
877-
input, filter_backprop);
932+
DepthwiseConv2dBackpropFilterGPULaunch<T>().Run(
933+
d, args, out_backprop, input, filter_backprop, data_format);
878934
OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for "
879935
"DepthwiseConv2dBackpropFil"
880936
"terGPULaunch failed"));
@@ -893,12 +949,23 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
893949
OP_REQUIRES(context, strides_.size() == 4,
894950
errors::InvalidArgument("Sliding window strides field must "
895951
"specify 4 dimensions"));
896-
OP_REQUIRES(context, strides_[1] == strides_[2],
952+
953+
string data_format;
954+
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
955+
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
956+
errors::InvalidArgument("Invalid data format"));
957+
958+
stride_ = GetTensorDim(strides_, data_format_, 'H');
959+
const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
960+
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
961+
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
962+
963+
OP_REQUIRES(context, stride_ == stride_w,
897964
errors::InvalidArgument(
898965
"Current implementation only supports equal length "
899966
"strides in the row and column dimensions."));
900967
OP_REQUIRES(
901-
context, (strides_[0] == 1 && strides_[3] == 1),
968+
context, (stride_n == 1 && stride_c == 1),
902969
errors::InvalidArgument("Current implementation does not yet support "
903970
"strides in the batch and depth dimensions."));
904971
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
@@ -935,12 +1002,15 @@ class DepthwiseConv2dNativeBackpropFilterOp : public OpKernel {
9351002
return;
9361003
}
9371004
LaunchDepthwiseConvBackpropFilterOp<Device, T>::launch(
938-
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr);
1005+
context, args, out_backprop_ptr, input_ptr, filter_backprop_ptr,
1006+
data_format_);
9391007
}
9401008
9411009
private:
9421010
std::vector<int32> strides_;
9431011
Padding padding_;
1012+
TensorFormat data_format_;
1013+
int64 stride_;
9441014
9451015
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeBackpropFilterOp);
9461016
};

0 commit comments

Comments
 (0)