@@ -122,6 +122,70 @@ def _impl(inputs, attr, params):
122122 return get_relay_op (name )(* inputs )
123123 return _impl
124124
125+ def _pool3d (name ):
126+ def _impl (inputs , attr , params ):
127+ attr ['data_format' ] = attr ['data_format' ].decode ("utf-8" )
128+ flip_layout = False
129+
130+ input_shape = attr ['_input_shapes' ][inputs [0 ]]
131+
132+ if attr ['data_format' ] == 'NDHWC' :
133+ attr ['kernel_shape' ] = (attr ['ksize' ][1 ], attr ['ksize' ][2 ], attr ['ksize' ][3 ])
134+ attr ['strides' ] = (attr ['strides' ][1 ], attr ['strides' ][2 ], attr ['strides' ][3 ])
135+ elif attr ['data_format' ] == 'NCDHW' :
136+ attr ['kernel_shape' ] = (attr ['ksize' ][2 ], attr ['ksize' ][3 ], attr ['ksize' ][4 ])
137+ attr ['strides' ] = (attr ['strides' ][2 ], attr ['strides' ][3 ], attr ['strides' ][4 ])
138+ else :
139+ msg = 'Value {} of attribute "data_format" of operator Pooling ' \
140+ 'is not valid.'
141+ raise tvm .error .OpAttributeInvalid (msg .format (attr ['data_format' ]))
142+ if attr ['data_format' ] == "NDHWC" :
143+ input_shape = [attr ['_input_shapes' ][inputs [0 ]][i ] for i in (0 , 4 , 1 , 2 , 3 )]
144+ inputs [0 ] = _op .transpose (inputs [0 ], axes = (0 , 4 , 1 , 2 , 3 ))
145+ attr ['data_format' ] = "NCDHW"
146+ attr ['_input_shapes' ][inputs [0 ]] = input_shape
147+ flip_layout = True
148+
149+ attr ['padding' ] = attr ['padding' ].decode ("utf-8" )
150+
151+ if attr ['padding' ] == 'VALID' :
152+ attr ['padding' ] = [0 , 0 , 0 , 0 , 0 , 0 ]
153+ elif attr ['padding' ] == 'SAME' :
154+ stride_d , stride_h , stride_w = attr ['strides' ]
155+ kernel_d , kernel_h , kernel_w = attr ['kernel_shape' ]
156+ if attr ['data_format' ] == 'NDHWC' :
157+ in_d = input_shape [1 ]
158+ in_h = input_shape [2 ]
159+ in_w = input_shape [3 ]
160+ else :
161+ in_d = input_shape [2 ]
162+ in_h = input_shape [3 ]
163+ in_w = input_shape [4 ]
164+ pad_d = _get_pad_pair (in_d , kernel_d , stride_d )
165+ pad_v = _get_pad_pair (in_h , kernel_h , stride_h )
166+ pad_h = _get_pad_pair (in_w , kernel_w , stride_w )
167+
168+ attr ['padding' ] = [pad_d [0 ], pad_v [0 ], pad_h [0 ], pad_d [1 ], pad_v [1 ], pad_h [1 ]]
169+ else :
170+ msg = 'Value {} in attribute "padding" of operator Pooling is ' \
171+ 'not valid.'
172+ raise tvm .error .OpAttributeInvalid (msg .format (attr ['padding' ]))
173+
174+ if name == "avg_pool" :
175+ attr ['count_include_pad' ] = False
176+ attr ['ceil_mode' ] = False
177+ out = AttrCvt (
178+ op_name = name ,
179+ transforms = {
180+ 'kernel_shape' : 'pool_size' ,
181+ 'data_format' : 'layout' },
182+ ignores = ['ksize' ])(inputs , attr )
183+ if flip_layout :
184+ out = _op .transpose (out , axes = (0 , 2 , 3 , 4 , 1 ))
185+ return out
186+
187+ return _impl
188+
125189def _pooling (name ):
126190 def _impl (inputs , attr , params ):
127191
@@ -1409,6 +1473,7 @@ def _impl(inputs, attr, params):
14091473 'ArgMin' : _argx (_op .argmin , 'argmin' ),
14101474 'Assert' : _assert (),
14111475 'AvgPool' : _pooling ('avg_pool' ),
1476+ 'AvgPool3D' : _pool3d ('avg_pool3d' ),
14121477 'BatchMatMul' : _batch_matmul (),
14131478 'BatchMatMulV2' : _batch_matmul (),
14141479 'BatchNormWithGlobalNormalization' : _batch_norm (),
@@ -1460,6 +1525,7 @@ def _impl(inputs, attr, params):
14601525 'MatMul' : _matmul (),
14611526 'Max' : _reduce ('max' ),
14621527 'MaxPool' : _pooling ('max_pool' ),
1528+ 'MaxPool3D' : _pool3d ('max_pool3d' ),
14631529 'Maximum' : _elemwise ('maximum' ),
14641530 'Mean' : _mean (),
14651531 'Min' : _reduce ('min' ),
0 commit comments