Skip to content

Commit ef1e101

Browse files
pvaneckNikhil Thorat
authored andcommitted
Change conv3d expected data formats (tensorflow#1620)
BUG The currently listed data formats apply to conv2d, so update the data formats for conv3d. NHWC -> NDHWC NCHW -> NCDHW
1 parent 9b9f5fd commit ef1e101

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

src/ops/conv.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -714,10 +714,10 @@ function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
714714
* - For more info, see this guide:
715715
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
716716
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
717-
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
718-
* "NHWC". Specify the data format of the input and output data. With the
719-
* default format "NHWC", the data is stored in the order of: [batch,
720-
* depth, height, width, channels]. Only "NHWC" is currently supported.
717+
* @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to
718+
* "NDHWC". Specify the data format of the input and output data. With the
719+
* default format "NDHWC", the data is stored in the order of: [batch,
720+
* depth, height, width, channels]. Only "NDHWC" is currently supported.
721721
* @param dilations The dilation rates: `[dilationDepth, dilationHeight,
722722
* dilationWidth]` in which we sample input values across the height
723723
* and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
@@ -730,7 +730,7 @@ function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
730730
function conv3d_<T extends Tensor4D|Tensor5D>(
731731
x: T|TensorLike, filter: Tensor5D|TensorLike,
732732
strides: [number, number, number]|number, pad: 'valid'|'same',
733-
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
733+
dataFormat: 'NDHWC'|'NCDHW' = 'NDHWC',
734734
dilations: [number, number, number]|number = [1, 1, 1]): T {
735735
const $x = convertToTensor(x, 'x', 'conv3d');
736736
const $filter = convertToTensor(filter, 'filter', 'conv3d');
@@ -758,9 +758,9 @@ function conv3d_<T extends Tensor4D|Tensor5D>(
758758
() => 'Error in conv3D: Either strides or dilations must be 1. ' +
759759
`Got strides ${strides} and dilations '${dilations}'`);
760760
util.assert(
761-
dataFormat === 'NHWC',
761+
dataFormat === 'NDHWC',
762762
() => `Error in conv3d: got dataFormat of ${
763-
dataFormat} but only NHWC is currently supported.`);
763+
dataFormat} but only NDHWC is currently supported.`);
764764

765765
const convInfo = conv_util.computeConv3DInfo(
766766
x5D.shape, $filter.shape, strides, dilations, pad);

src/ops/conv3d_test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,4 +468,19 @@ describeWithFlags('conv3d', ALL_ENVS, () => {
468468
const result = tf.conv3d(x, w, stride, pad);
469469
expectArraysClose(result, [2, 4, 6, 8]);
470470
});
471+
472+
it('throws when data format not NDHWC', () => {
473+
const inputDepth = 1;
474+
const outputDepth = 1;
475+
const inputShape: [number, number, number, number] = [2, 2, 1, inputDepth];
476+
const pad = 'valid';
477+
const fSize = 1;
478+
const stride = 1;
479+
const dataFormat = 'NCDHW';
480+
481+
const x = tf.tensor4d([1, 2, 3, 4], inputShape);
482+
const w = tf.tensor5d([2], [fSize, fSize, fSize, inputDepth, outputDepth]);
483+
484+
expect(() => tf.conv3d(x, w, stride, pad, dataFormat)).toThrowError();
485+
});
471486
});

0 commit comments

Comments
 (0)