Skip to content

Commit 67f16f6

Browse files
yongwwwzhiics
authored and committed
[Relay] add max_pool3d in relay and TF converter (apache#4551)
* [Relay] add max_pool3d in relay and TF converter * fix comments
1 parent 70a34af commit 67f16f6

File tree

11 files changed

+382
-62
lines changed

11 files changed

+382
-62
lines changed

docs/langref/relay_op.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,9 @@ This level enables typical convnet models.
7171
tvm.relay.nn.conv2d_transpose
7272
tvm.relay.nn.dense
7373
tvm.relay.nn.max_pool2d
74+
tvm.relay.nn.max_pool3d
7475
tvm.relay.nn.avg_pool2d
76+
tvm.relay.nn.avg_pool3d
7577
tvm.relay.nn.global_max_pool2d
7678
tvm.relay.nn.global_avg_pool2d
7779
tvm.relay.nn.upsampling
@@ -246,7 +248,9 @@ Level 2 Definitions
246248
.. autofunction:: tvm.relay.nn.conv2d_transpose
247249
.. autofunction:: tvm.relay.nn.dense
248250
.. autofunction:: tvm.relay.nn.max_pool2d
251+
.. autofunction:: tvm.relay.nn.max_pool3d
249252
.. autofunction:: tvm.relay.nn.avg_pool2d
253+
.. autofunction:: tvm.relay.nn.avg_pool3d
250254
.. autofunction:: tvm.relay.nn.global_max_pool2d
251255
.. autofunction:: tvm.relay.nn.global_avg_pool2d
252256
.. autofunction:: tvm.relay.nn.upsampling

python/tvm/relay/_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,10 @@ def __call__(self, args, attrs, type_args):
135135
"nn.dense": op.nn.dense,
136136
"nn.bias_add": op.nn.bias_add,
137137
"nn.max_pool2d": op.nn.max_pool2d,
138+
"nn.max_pool3d": op.nn.max_pool3d,
138139
"nn.global_max_pool2d": op.nn.global_max_pool2d,
139140
"nn.avg_pool2d": op.nn.avg_pool2d,
141+
"nn.avg_pool3d": op.nn.avg_pool3d,
140142
"nn.global_avg_pool2d": op.nn.global_avg_pool2d,
141143
"nn.softmax": op.nn.softmax,
142144
"reshape": op.reshape,

python/tvm/relay/frontend/tensorflow.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,70 @@ def _impl(inputs, attr, params):
122122
return get_relay_op(name)(*inputs)
123123
return _impl
124124

125+
def _pool3d(name):
126+
def _impl(inputs, attr, params):
127+
attr['data_format'] = attr['data_format'].decode("utf-8")
128+
flip_layout = False
129+
130+
input_shape = attr['_input_shapes'][inputs[0]]
131+
132+
if attr['data_format'] == 'NDHWC':
133+
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2], attr['ksize'][3])
134+
attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
135+
elif attr['data_format'] == 'NCDHW':
136+
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3], attr['ksize'][4])
137+
attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
138+
else:
139+
msg = 'Value {} of attribute "data_format" of operator Pooling ' \
140+
'is not valid.'
141+
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
142+
if attr['data_format'] == "NDHWC":
143+
input_shape = [attr['_input_shapes'][inputs[0]][i] for i in (0, 4, 1, 2, 3)]
144+
inputs[0] = _op.transpose(inputs[0], axes=(0, 4, 1, 2, 3))
145+
attr['data_format'] = "NCDHW"
146+
attr['_input_shapes'][inputs[0]] = input_shape
147+
flip_layout = True
148+
149+
attr['padding'] = attr['padding'].decode("utf-8")
150+
151+
if attr['padding'] == 'VALID':
152+
attr['padding'] = [0, 0, 0, 0, 0, 0]
153+
elif attr['padding'] == 'SAME':
154+
stride_d, stride_h, stride_w = attr['strides']
155+
kernel_d, kernel_h, kernel_w = attr['kernel_shape']
156+
if attr['data_format'] == 'NDHWC':
157+
in_d = input_shape[1]
158+
in_h = input_shape[2]
159+
in_w = input_shape[3]
160+
else:
161+
in_d = input_shape[2]
162+
in_h = input_shape[3]
163+
in_w = input_shape[4]
164+
pad_d = _get_pad_pair(in_d, kernel_d, stride_d)
165+
pad_v = _get_pad_pair(in_h, kernel_h, stride_h)
166+
pad_h = _get_pad_pair(in_w, kernel_w, stride_w)
167+
168+
attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
169+
else:
170+
msg = 'Value {} in attribute "padding" of operator Pooling is ' \
171+
'not valid.'
172+
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
173+
174+
if name == "avg_pool":
175+
attr['count_include_pad'] = False
176+
attr['ceil_mode'] = False
177+
out = AttrCvt(
178+
op_name=name,
179+
transforms={
180+
'kernel_shape': 'pool_size',
181+
'data_format': 'layout'},
182+
ignores=['ksize'])(inputs, attr)
183+
if flip_layout:
184+
out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
185+
return out
186+
187+
return _impl
188+
125189
def _pooling(name):
126190
def _impl(inputs, attr, params):
127191

@@ -1409,6 +1473,7 @@ def _impl(inputs, attr, params):
14091473
'ArgMin' : _argx(_op.argmin, 'argmin'),
14101474
'Assert' : _assert(),
14111475
'AvgPool' : _pooling('avg_pool'),
1476+
'AvgPool3D' : _pool3d('avg_pool3d'),
14121477
'BatchMatMul' : _batch_matmul(),
14131478
'BatchMatMulV2' : _batch_matmul(),
14141479
'BatchNormWithGlobalNormalization' : _batch_norm(),
@@ -1460,6 +1525,7 @@ def _impl(inputs, attr, params):
14601525
'MatMul' : _matmul(),
14611526
'Max' : _reduce('max'),
14621527
'MaxPool' : _pooling('max_pool'),
1528+
'MaxPool3D' : _pool3d('max_pool3d'),
14631529
'Maximum' : _elemwise('maximum'),
14641530
'Mean' : _mean(),
14651531
'Min' : _reduce('min'),

python/tvm/relay/op/nn/_nn.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,18 @@ def schedule_max_pool2d(attrs, outs, target):
396396
reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
397397

398398

399+
# max_pool3d
400+
@reg.register_schedule("nn.max_pool3d")
401+
def schedule_max_pool3d(attrs, outs, target):
402+
"""Schedule definition of max_pool3d"""
403+
layout = attrs.layout
404+
with target:
405+
return topi.generic.schedule_pool(outs, layout)
406+
407+
408+
reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
409+
410+
399411
# avg_pool2d
400412
@reg.register_schedule("nn.avg_pool2d")
401413
def schedule_avg_pool2d(attrs, outs, target):
@@ -404,10 +416,21 @@ def schedule_avg_pool2d(attrs, outs, target):
404416
with target:
405417
return topi.generic.schedule_pool(outs, layout)
406418

407-
408419
reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
409420

410421

422+
# avg_pool3d
423+
@reg.register_schedule("nn.avg_pool3d")
424+
def schedule_avg_pool3d(attrs, outs, target):
425+
"""Schedule definition of avg_pool3d"""
426+
layout = attrs.layout
427+
with target:
428+
return topi.generic.schedule_pool(outs, layout)
429+
430+
431+
reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
432+
433+
411434
# max_pool2d_grad
412435
@reg.register_schedule("nn.max_pool2d_grad")
413436
def schedule_max_pool2d_grad(attrs, outs, target):

python/tvm/relay/op/nn/nn.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,51 @@ def max_pool2d(data,
425425
return _make.max_pool2d(data, pool_size, strides, padding,
426426
layout, ceil_mode)
427427

428+
def max_pool3d(data,
429+
pool_size=(1, 1, 1),
430+
strides=(1, 1, 1),
431+
padding=(0, 0, 0),
432+
layout="NCDHW",
433+
ceil_mode=False):
434+
r"""3D maximum pooling operator.
435+
436+
This operator takes data as input and does 3D max value calculation
437+
within pool_size sized window by striding defined by stride.
438+
439+
440+
In the default case, where the data_layout is `NCDHW`
441+
a data Tensor with shape `(batch_size, channels, depth, height, width)`,
442+
to produce an output Tensor.
443+
444+
The ceil_mode is used to take ceil or floor while computing out shape.
445+
count_include_pad indicates including or excluding padded input values in computation.
446+
This operator accepts data layout specification.
447+
448+
Parameters
449+
----------
450+
data : tvm.relay.Expr
451+
The input data to the operator.
452+
453+
strides : tuple of int, optional
454+
The strides of pooling.
455+
456+
padding : tuple of int, optional
457+
The padding for pooling.
458+
459+
layout : str, optional
460+
Layout of the input.
461+
462+
ceil_mode : bool, optional
463+
To enable or disable ceil while pooling.
464+
465+
Returns
466+
-------
467+
result : tvm.relay.Expr
468+
The computed result.
469+
"""
470+
return _make.max_pool3d(data, pool_size, strides, padding,
471+
layout, ceil_mode)
472+
428473
def avg_pool2d(data,
429474
pool_size=(1, 1),
430475
strides=(1, 1),
@@ -482,6 +527,55 @@ def avg_pool2d(data,
482527
return _make.avg_pool2d(data, pool_size, strides, padding,
483528
layout, ceil_mode, count_include_pad)
484529

530+
def avg_pool3d(data,
531+
pool_size=(1, 1, 1),
532+
strides=(1, 1, 1),
533+
padding=(0, 0, 0),
534+
layout="NCDHW",
535+
ceil_mode=False,
536+
count_include_pad=False):
537+
r"""3D average pooling operator.
538+
539+
This operator takes data as input and does 3D average value calculation
540+
within pool_size sized window by striding defined by stride.
541+
542+
543+
In the default case, where the data_layout is `NCDHW`
544+
a data Tensor with shape `(batch_size, channels, depth, height, width)`,
545+
to produce an output Tensor.
546+
547+
The ceil_mode is used to take ceil or floor while computing out shape.
548+
count_include_pad indicates including or excluding padded input values in computation.
549+
This operator accepts data layout specification.
550+
551+
Parameters
552+
----------
553+
data : tvm.relay.Expr
554+
The input data to the operator.
555+
556+
strides : tuple of int, optional
557+
The strides of pooling.
558+
559+
padding : tuple of int, optional
560+
The padding for pooling.
561+
562+
layout : str, optional
563+
Layout of the input.
564+
565+
ceil_mode : bool, optional
566+
To enable or disable ceil while pooling.
567+
568+
count_include_pad : bool, optional
569+
To include padding to compute the average.
570+
571+
Returns
572+
-------
573+
result : tvm.relay.Expr
574+
The computed result.
575+
"""
576+
return _make.avg_pool3d(data, pool_size, strides, padding,
577+
layout, ceil_mode, count_include_pad)
578+
485579
def max_pool2d_grad(out_grad,
486580
data,
487581
pool_size=(1, 1),

python/tvm/relay/op/op_attrs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,16 @@ class AvgPool2DAttrs(Attrs):
271271
"""Attributes used in avg_pool2d operators"""
272272

273273

274+
@register_relay_attr_node
275+
class MaxPool3DAttrs(Attrs):
276+
"""Attributes used in max_pool3d operators"""
277+
278+
279+
@register_relay_attr_node
280+
class AvgPool3DAttrs(Attrs):
281+
"""Attributes used in avg_pool3d operators"""
282+
283+
274284
@register_relay_attr_node
275285
class BitPackAttrs(Attrs):
276286
"""Attributes used in bitpack operator"""

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -237,16 +237,58 @@ def _test_pooling_iteration(input_shape, **kwargs):
237237
def _test_pooling(input_shape, **kwargs):
238238
_test_pooling_iteration(input_shape, **kwargs)
239239

240-
if is_gpu_available() and (len(input_shape) == 4):
241-
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
242-
kwargs['data_format'] = 'NCHW'
243-
_test_pooling_iteration(input_shape, **kwargs)
240+
if is_gpu_available():
241+
if len(input_shape) == 4:
242+
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
243+
kwargs['data_format'] = 'NCHW'
244+
_test_pooling_iteration(input_shape, **kwargs)
244245

245246

246247
def test_forward_pooling():
247248
""" Pooling """
248-
249+
# TensorFlow only supports NDHWC for max_pool3d on CPU
249250
for pool_type in ['AVG', 'MAX']:
251+
# NDHWC is the default layout for max_pool3d and avg_pool3d in TensorFlow
252+
_test_pooling(input_shape=[1, 3, 32, 32, 32],
253+
window_shape=[2, 2, 2],
254+
padding='VALID',
255+
pooling_type=pool_type,
256+
dilation_rate=[1, 1, 1],
257+
strides=[2, 2, 2])
258+
259+
_test_pooling(input_shape=[1, 3, 32, 32, 32],
260+
window_shape=[1, 1, 1],
261+
padding='SAME',
262+
pooling_type=pool_type,
263+
dilation_rate=[1, 1, 1],
264+
strides=[1, 1, 1])
265+
266+
_test_pooling(input_shape=[1, 3, 32, 32, 32],
267+
window_shape=[2, 2, 2],
268+
padding='SAME',
269+
pooling_type=pool_type,
270+
dilation_rate=[1, 1, 1],
271+
strides=[2, 2, 2])
272+
273+
# test cases for max_pool3d & avg_pool3d with layout NCDHW
274+
# TensorFlow pool3d doesn't support NCDHW on cpu
275+
if is_gpu_available():
276+
_test_pooling(input_shape=[1, 3, 32, 32, 32],
277+
window_shape=[1, 1, 1],
278+
padding='SAME',
279+
pooling_type=pool_type,
280+
dilation_rate=[1, 1, 1],
281+
strides=[1, 1, 1],
282+
data_format='NCDHW')
283+
284+
_test_pooling(input_shape=[1, 3, 32, 32, 32],
285+
window_shape=[2, 2, 2],
286+
padding='VALID',
287+
pooling_type=pool_type,
288+
dilation_rate=[1, 1, 1],
289+
strides=[2, 2, 2],
290+
data_format='NCDHW')
291+
250292
_test_pooling(input_shape=[2, 9, 10, 2],
251293
window_shape=[1, 1],
252294
padding='SAME',
@@ -2855,7 +2897,6 @@ def test_forward_add_n():
28552897
test_forward_sin()
28562898
test_forward_negative()
28572899
test_forward_divide()
2858-
test_forward_floordiv()
28592900
test_forward_abs()
28602901
test_forward_softplus()
28612902
test_forward_sqrt()
@@ -2916,5 +2957,3 @@ def test_forward_add_n():
29162957
test_forward_where()
29172958
test_forward_matmul()
29182959
test_forward_batch_matmul()
2919-
2920-
# TODO missing tests: rank

0 commit comments

Comments
 (0)