Skip to content

Commit b9734a0

Browse files
committed
[Relay] Add ResizeNearestNeighbor and CropAndResize in tf converter
1 parent df6957a commit b9734a0

File tree

3 files changed

+131
-10
lines changed

3 files changed

+131
-10
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,60 @@ def _impl(inputs, attr, params):
484484
return inputs[0]
485485
return _impl
486486

487+
def _crop_and_resize():
488+
def _impl(inputs, attr, params):
489+
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
490+
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
491+
boxes = params.pop(inputs[1].name_hint).asnumpy()
492+
data_shape = attr['_input_shapes'][inputs[0]]
493+
data_dim = len(data_shape)
494+
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
495+
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
496+
method = attr['method'].decode()
497+
498+
# 1) Crop
499+
# y is mapped to the image coordinate at y * (image_height - 1)
500+
# x is mapped to the image coordinate at x * (image_width - 1)
501+
begin = [0] * data_dim
502+
begin[1] = float(boxes[0][0]) * (data_shape[1] - 1)
503+
begin[2] = int(round(boxes[0][1] * (data_shape[2] - 1)))
504+
size = data_shape[:]
505+
size[0] = 1
506+
size[1] = int(round((data_shape[1] - 1) * boxes[0][2])) + 1
507+
size[2] = int(round((data_shape[2] - 1) * boxes[0][3])) + 1
508+
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
509+
510+
# 2) Resize
511+
attrs = {};
512+
attrs['size'] = crop_size
513+
attrs['layout'] = 'NHWC'
514+
if method.lower() == 'nearest':
515+
raise tvm.error.OpAttributeUnimplemented(
516+
'Attribute method=nearest is not supported')
517+
else:
518+
attrs['align_corners'] = True
519+
attrs['method'] = 'BILINEAR'
520+
ret = _get_relay_op('resize')(res_crop, **attrs)
521+
522+
for idx in box_ind[1:]:
523+
# 1) Crop
524+
# y is mapped to the image coordinate at y * (image_height - 1)
525+
begin = [0] * data_dim
526+
begin[0] = idx
527+
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1)))
528+
begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1)))
529+
size = data_shape[:]
530+
size[0] = idx + 1
531+
size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1
532+
size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1
533+
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
534+
535+
# 2) Resize
536+
res_resize = _get_relay_op('resize')(res_crop, **attrs)
537+
ret = _op.concatenate([ret, res_resize], axis=0)
538+
return ret
539+
return _impl
540+
487541
def _cast():
488542
def _impl(inputs, attr, params):
489543
return inputs[0].astype(attr['DstT'].name)
@@ -514,6 +568,21 @@ def _impl(inputs, attr, params):
514568
extras={'method': "BILINEAR"})(inputs, attr)
515569
return _impl
516570

571+
def _resize_nearest_neighbor():
572+
def _impl(inputs, attr, params):
573+
size = attr['_output_shapes'][0][1:3]
574+
if -1 in size:
575+
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
576+
attr['size'] = size
577+
inputs.pop(1)
578+
# NHWC
579+
attr['layout'] = 'NHWC'
580+
581+
return AttrCvt(op_name="resize",
582+
ignores=['Tdim'],
583+
extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr)
584+
return _impl
585+
517586
def _check_numerics():
518587
def _impl(inputs, attr, params):
519588
# Making a copy node assuming no need to verify
@@ -593,7 +662,7 @@ def _impl(inputs, attr, params):
593662
end[i] = data_shape[i] - begin[i]
594663
else:
595664
end[i] += begin[i]
596-
return _op.strided_slice(inputs[0], begin=begin, end=size)
665+
return _op.strided_slice(inputs[0], begin=begin, end=end)
597666
return _impl
598667

599668

@@ -1243,6 +1312,7 @@ def _impl(inputs, attr, params):
12431312
'Concat' : _concat(),
12441313
'ConcatV2' : _concatV2(),
12451314
'Conv2D' : _conv('conv'),
1315+
'CropAndResize' : _crop_and_resize(),
12461316
'DecodeJpeg' : _decode_image(),
12471317
'DepthwiseConv2dNative' : _conv('depthwise'),
12481318
'DepthToSpace' : _depth_to_space(),
@@ -1295,6 +1365,7 @@ def _impl(inputs, attr, params):
12951365
'Reshape' : _reshape(),
12961366
'ResizeBilinear' : _resize_bilinear(),
12971367
'ResizeBicubic' : _resize_bilinear(),
1368+
'ResizeNearestNeighbor' : _resize_nearest_neighbor(),
12981369
'ReverseV2' : _reverse_v2(),
12991370
'RightShift' : AttrCvt('right_shift'),
13001371
'Round' : AttrCvt('round'),

python/tvm/relay/frontend/tensorflow_parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import absolute_import as _abs
1919
from __future__ import print_function
2020
import os
21-
from tensorflow.core.framework import graph_pb2
2221
from tvm.contrib import util
2322

2423

@@ -35,12 +34,12 @@ class TFParser(object):
3534
--------
3635
.. code-block:: python
3736
38-
parser = TfParser(model_dir)
39-
graph = parser.parse()
40-
# graph is related graphdef of the model
37+
parser = TFParser(model_dir)
38+
graphdef = parser.parse()
4139
"""
4240

4341
def __init__(self, model_dir):
42+
from tensorflow.core.framework import graph_pb2
4443
self._tmp_dir = util.tempdir()
4544
self._model_dir = model_dir
4645
self._graph = graph_pb2.GraphDef()
@@ -96,6 +95,7 @@ def _load_saved_model(self):
9695
from tensorflow.python.tools import freeze_graph
9796
from tensorflow.python.framework import ops
9897
from tensorflow.python.framework import graph_util
98+
from tensorflow.core.framework import graph_pb2
9999
except ImportError:
100100
raise ImportError(
101101
"InputConfiguration: Unable to import tensorflow which is "

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -969,8 +969,8 @@ def test_forward_multi_output():
969969
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
970970

971971
#######################################################################
972-
# Resize Bilinear
973-
# ---------------
972+
# Resize Bilinear, Nearest_Neighbor
973+
# ---------------------------------
974974

975975
def _test_resize_bilinear(in_shape, to_shape, align_corners):
976976
""" One iteration of resize bilinear """
@@ -1000,13 +1000,31 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners):
10001000

10011001
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
10021002

1003-
def test_forward_resize_bilinear():
1004-
""" Resize Bilinear """
1003+
1004+
def _test_resize_nearest_neighbor(in_shape, to_shape):
1005+
""" One iteration of resize nearest neighbor """
1006+
1007+
data = np.random.uniform(size=in_shape).astype('float32')
1008+
shape_data = np.array(to_shape).astype('int32')
1009+
1010+
with tf.Graph().as_default():
1011+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
1012+
shape_data = constant_op.constant(
1013+
shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
1014+
tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor')
1015+
1016+
compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
1017+
1018+
1019+
def test_forward_resize():
1020+
""" Resize Bilinear, Nearest_Neighbor """
10051021

10061022
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
10071023
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
10081024
_test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
10091025
_test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
1026+
_test_resize_nearest_neighbor((6, 32, 64, 64), [20, 20])
1027+
10101028

10111029
#######################################################################
10121030
# BroadcastTo
@@ -1100,6 +1118,37 @@ def test_forward_crop():
11001118
_test_crop((1, 224, 224, 3), 20, 20, 120, 120)
11011119

11021120

1121+
#######################################################################
1122+
# CropAndResize
1123+
# -------------
1124+
1125+
def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"):
1126+
image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
1127+
tf.reset_default_graph()
1128+
in_data = tf.placeholder(dtype, image.shape, name="in_data")
1129+
tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size,
1130+
method=method, name="crop_and_resize")
1131+
compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
1132+
1133+
def test_forward_crop_and_resize():
1134+
""" CropAndResize """
1135+
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]],
1136+
[0], [5, 5])
1137+
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]],
1138+
[0], [5, 5])
1139+
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]],
1140+
[0], [5, 5])
1141+
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]],
1142+
[0], [3, 4])
1143+
_test_forward_crop_and_resize([10, 11, 11, 3],
1144+
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
1145+
[0, 1], [5, 5])
1146+
_test_forward_crop_and_resize([3, 11, 11, 3],
1147+
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]],
1148+
[0, 1, 2], [3, 3])
1149+
_test_forward_crop_and_resize([1, 16, 16, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
1150+
1151+
11031152
#######################################################################
11041153
# LSTM
11051154
# ----
@@ -1989,10 +2038,11 @@ def test_placeholder():
19892038
test_forward_depthtospace()
19902039
test_forward_squeeze()
19912040
test_forward_pack()
1992-
test_forward_resize_bilinear()
19932041
test_forward_broadcast_to()
19942042
test_forward_fill()
19952043
test_forward_crop()
2044+
test_forward_resize()
2045+
test_forward_crop_and_resize()
19962046
test_forward_pad()
19972047
test_forward_unpack()
19982048
test_forward_gather()

0 commit comments

Comments
 (0)