Skip to content

Commit 0455afb

Browse files
committed
[Relay] Add ResizeNearestNeighbor and CropAndResize in tf converter
1 parent 25bad44 commit 0455afb

File tree

3 files changed

+127
-10
lines changed

3 files changed

+127
-10
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,54 @@ def _impl(inputs, attr, params):
484484
return inputs[0]
485485
return _impl
486486

487+
def _crop_and_resize():
488+
def _impl(inputs, attr, params):
489+
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
490+
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
491+
try:
492+
boxes = params.pop(inputs[1].name_hint).asnumpy().tolist()
493+
box_ind = params.pop(inputs[2].name_hint).asnumpy().tolist()
494+
crop_size = params.pop(inputs[3].name_hint).asnumpy().tolist()
495+
except (IndexError, KeyError):
496+
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
497+
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
498+
crop_size = _infer_value(inputs[3], params).asnumpy().tolist()
499+
500+
data_shape = attr['_input_shapes'][inputs[0]]
501+
data_dim = len(data_shape)
502+
method = attr['method'].decode()
503+
504+
attrs = {}
505+
attrs['size'] = crop_size
506+
attrs['layout'] = 'NHWC'
507+
if method.lower() == 'nearest':
508+
raise tvm.error.OpAttributeUnimplemented(
509+
'Attribute method=nearest is not supported')
510+
else:
511+
attrs['align_corners'] = True
512+
attrs['method'] = 'BILINEAR'
513+
514+
out = None
515+
begin = [0] * data_dim
516+
size = data_shape[:]
517+
for idx in box_ind:
518+
# 1) Crop
519+
# y is mapped to the image coordinate at y * (image_height - 1)
520+
# x is mapped to the image coordinate at x * (image_width - 1)
521+
begin[0] = idx
522+
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1)))
523+
begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1)))
524+
size[0] = idx + 1
525+
size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1
526+
size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1
527+
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
528+
529+
# 2) Resize
530+
res_resize = _get_relay_op('resize')(res_crop, **attrs)
531+
out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
532+
return out
533+
return _impl
534+
487535
def _cast():
488536
def _impl(inputs, attr, params):
489537
return inputs[0].astype(attr['DstT'].name)
@@ -514,6 +562,21 @@ def _impl(inputs, attr, params):
514562
extras={'method': "BILINEAR"})(inputs, attr)
515563
return _impl
516564

565+
def _resize_nearest_neighbor():
566+
def _impl(inputs, attr, params):
567+
size = attr['_output_shapes'][0][1:3]
568+
if -1 in size:
569+
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
570+
attr['size'] = size
571+
inputs.pop(1)
572+
# NHWC
573+
attr['layout'] = 'NHWC'
574+
575+
return AttrCvt(op_name="resize",
576+
ignores=['Tdim'],
577+
extras={'method': "NEAREST_NEIGHBOR"})(inputs, attr)
578+
return _impl
579+
517580
def _check_numerics():
518581
def _impl(inputs, attr, params):
519582
# Making a copy node assuming no need to verify
@@ -593,7 +656,7 @@ def _impl(inputs, attr, params):
593656
end[i] = data_shape[i] - begin[i]
594657
else:
595658
end[i] += begin[i]
596-
return _op.strided_slice(inputs[0], begin=begin, end=size)
659+
return _op.strided_slice(inputs[0], begin=begin, end=end)
597660
return _impl
598661

599662

@@ -1243,6 +1306,7 @@ def _impl(inputs, attr, params):
12431306
'Concat' : _concat(),
12441307
'ConcatV2' : _concatV2(),
12451308
'Conv2D' : _conv('conv'),
1309+
'CropAndResize' : _crop_and_resize(),
12461310
'DecodeJpeg' : _decode_image(),
12471311
'DepthwiseConv2dNative' : _conv('depthwise'),
12481312
'DepthToSpace' : _depth_to_space(),
@@ -1295,6 +1359,7 @@ def _impl(inputs, attr, params):
12951359
'Reshape' : _reshape(),
12961360
'ResizeBilinear' : _resize_bilinear(),
12971361
'ResizeBicubic' : _resize_bilinear(),
1362+
'ResizeNearestNeighbor' : _resize_nearest_neighbor(),
12981363
'ReverseV2' : _reverse_v2(),
12991364
'RightShift' : AttrCvt('right_shift'),
13001365
'Round' : AttrCvt('round'),

python/tvm/relay/frontend/tensorflow_parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import absolute_import as _abs
1919
from __future__ import print_function
2020
import os
21-
from tensorflow.core.framework import graph_pb2
2221
from tvm.contrib import util
2322

2423

@@ -35,12 +34,12 @@ class TFParser(object):
3534
--------
3635
.. code-block:: python
3736
38-
parser = TfParser(model_dir)
39-
graph = parser.parse()
40-
# graph is related graphdef of the model
37+
parser = TFParser(model_dir)
38+
graphdef = parser.parse()
4139
"""
4240

4341
def __init__(self, model_dir):
42+
from tensorflow.core.framework import graph_pb2
4443
self._tmp_dir = util.tempdir()
4544
self._model_dir = model_dir
4645
self._graph = graph_pb2.GraphDef()
@@ -96,6 +95,7 @@ def _load_saved_model(self):
9695
from tensorflow.python.tools import freeze_graph
9796
from tensorflow.python.framework import ops
9897
from tensorflow.python.framework import graph_util
98+
from tensorflow.core.framework import graph_pb2
9999
except ImportError:
100100
raise ImportError(
101101
"InputConfiguration: Unable to import tensorflow which is "

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -949,8 +949,8 @@ def test_forward_multi_output():
949949
tvm.testing.assert_allclose(tf_output[i], tvm_output[i], atol=1e-5, rtol=1e-5)
950950

951951
#######################################################################
952-
# Resize Bilinear
953-
# ---------------
952+
# Resize Bilinear, Nearest_Neighbor
953+
# ---------------------------------
954954

955955
def _test_resize_bilinear(in_shape, to_shape, align_corners):
956956
""" One iteration of resize bilinear """
@@ -980,13 +980,31 @@ def _test_resize_bilinear_from_tensor(in_shape, align_corners):
980980

981981
compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')
982982

983-
def test_forward_resize_bilinear():
984-
""" Resize Bilinear """
983+
984+
def _test_resize_nearest_neighbor(in_shape, to_shape):
985+
""" One iteration of resize nearest neighbor """
986+
987+
data = np.random.uniform(size=in_shape).astype('float32')
988+
shape_data = np.array(to_shape).astype('int32')
989+
990+
with tf.Graph().as_default():
991+
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
992+
shape_data = constant_op.constant(
993+
shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
994+
tf.image.resize_nearest_neighbor(in_data, shape_data, name='resize_nearest_neighbor')
995+
996+
compare_tf_with_tvm(data, 'Placeholder:0', 'resize_nearest_neighbor:0')
997+
998+
999+
def test_forward_resize():
1000+
""" Resize Bilinear, Nearest_Neighbor """
9851001

9861002
_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
9871003
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
9881004
_test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
9891005
_test_resize_bilinear_from_tensor((6, 32, 50, 50), True)
1006+
_test_resize_nearest_neighbor((6, 32, 64, 64), [20, 20])
1007+
9901008

9911009
#######################################################################
9921010
# BroadcastTo
@@ -1080,6 +1098,39 @@ def test_forward_crop():
10801098
_test_crop((1, 224, 224, 3), 20, 20, 120, 120)
10811099

10821100

1101+
#######################################################################
1102+
# CropAndResize
1103+
# -------------
1104+
1105+
def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"):
1106+
image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
1107+
tf.reset_default_graph()
1108+
in_data = tf.placeholder(dtype, image.shape, name="in_data")
1109+
tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size,
1110+
method=method, name="crop_and_resize")
1111+
compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
1112+
1113+
def test_forward_crop_and_resize():
1114+
""" CropAndResize """
1115+
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5])
1116+
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5])
1117+
_test_forward_crop_and_resize([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5])
1118+
_test_forward_crop_and_resize([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4])
1119+
_test_forward_crop_and_resize([1, 106, 106, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3])
1120+
_test_forward_crop_and_resize([10, 11, 11, 3],
1121+
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
1122+
[0, 1],
1123+
[5, 5])
1124+
_test_forward_crop_and_resize([3, 11, 11, 3],
1125+
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8],[0, 0, 1, 1]],
1126+
[0, 1, 2],
1127+
[3, 3])
1128+
_test_forward_crop_and_resize([3, 11, 11, 3],
1129+
[[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]],
1130+
[2, 1, 0],
1131+
[3, 3])
1132+
1133+
10831134
#######################################################################
10841135
# LSTM
10851136
# ----
@@ -1979,10 +2030,11 @@ def test_placeholder():
19792030
test_forward_depthtospace()
19802031
test_forward_squeeze()
19812032
test_forward_pack()
1982-
test_forward_resize_bilinear()
19832033
test_forward_broadcast_to()
19842034
test_forward_fill()
19852035
test_forward_crop()
2036+
test_forward_resize()
2037+
test_forward_crop_and_resize()
19862038
test_forward_pad()
19872039
test_forward_unpack()
19882040
test_forward_gather()

0 commit comments

Comments
 (0)