@@ -714,10 +714,10 @@ function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
714
714
* - For more info, see this guide:
715
715
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
716
716
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
717
- * @param dataFormat: An optional string from: "NHWC ", "NCHW ". Defaults to
718
- * "NHWC ". Specify the data format of the input and output data. With the
719
- * default format "NHWC ", the data is stored in the order of: [batch,
720
- * depth, height, width, channels]. Only "NHWC " is currently supported.
717
+ * @param dataFormat: An optional string from: "NDHWC ", "NCDHW ". Defaults to
718
+ * "NDHWC ". Specify the data format of the input and output data. With the
719
+ * default format "NDHWC ", the data is stored in the order of: [batch,
720
+ * depth, height, width, channels]. Only "NDHWC " is currently supported.
721
721
* @param dilations The dilation rates: `[dilationDepth, dilationHeight,
722
722
* dilationWidth]` in which we sample input values across the height
723
723
* and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`.
@@ -730,7 +730,7 @@ function depthwiseConv2dDerFilter<T extends Tensor3D|Tensor4D>(
730
730
function conv3d_ < T extends Tensor4D | Tensor5D > (
731
731
x : T | TensorLike , filter : Tensor5D | TensorLike ,
732
732
strides : [ number , number , number ] | number , pad : 'valid' | 'same' ,
733
- dataFormat : 'NHWC ' | 'NCHW ' = 'NHWC ' ,
733
+ dataFormat : 'NDHWC ' | 'NCDHW ' = 'NDHWC ' ,
734
734
dilations : [ number , number , number ] | number = [ 1 , 1 , 1 ] ) : T {
735
735
const $x = convertToTensor ( x , 'x' , 'conv3d' ) ;
736
736
const $filter = convertToTensor ( filter , 'filter' , 'conv3d' ) ;
@@ -758,9 +758,9 @@ function conv3d_<T extends Tensor4D|Tensor5D>(
758
758
( ) => 'Error in conv3D: Either strides or dilations must be 1. ' +
759
759
`Got strides ${ strides } and dilations '${ dilations } '` ) ;
760
760
util . assert (
761
- dataFormat === 'NHWC ' ,
761
+ dataFormat === 'NDHWC ' ,
762
762
( ) => `Error in conv3d: got dataFormat of ${
763
- dataFormat } but only NHWC is currently supported.`) ;
763
+ dataFormat } but only NDHWC is currently supported.`) ;
764
764
765
765
const convInfo = conv_util . computeConv3DInfo (
766
766
x5D . shape , $filter . shape , strides , dilations , pad ) ;
0 commit comments