100 changes: 87 additions & 13 deletions python/tvm/relay/frontend/tensorflow.py
@@ -34,6 +34,20 @@

__all__ = ['from_tensorflow']

def _infer_value(input_val, params):

Contributor:

We could add an additional check to make sure all inputs to infer are available in params.
A utility similar to list_input_names() in nnvm.

Not sure if we have a similar API in Relay.
cc @zhiics @jroesch @tqchen

Member:

Do you think something like this function could help?
https://github.com/dmlc/tvm/blob/master/python/tvm/relay/backend/interpreter.py#L136

Contributor:

I doubt it fits here.
What we need here is the list of Var nodes the current Expr depends on, recursively.

Member:

Can you just take free_vars if it is an expr? Var has name_hint.
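
A minimal sketch of that suggestion (assuming a relay expression e and the frontend's params dict; ir_pass is the tvm.relay.ir_pass module this file already imports):

free = ir_pass.free_vars(e)            # Var nodes that e depends on, recursively
names = [v.name_hint for v in free]    # each of these must be resolvable from params
assert all(name in params for name in names)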

Contributor:

free_vars should work.

@jwfromm, can you cross-check free_vars with params and add an assert?

Contributor Author:

Added in the latest push.

from tvm.contrib import graph_runtime
# Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in ir_pass.free_vars(
input_val)), "All inputs to infer must be available in params."
func = _expr.Function(ir_pass.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
return m.get_output(0)
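
For reference, a minimal sketch of how a converter is expected to call this helper (using the hypothetical converter arguments inputs, attr and params; the pattern mirrors the _resize_bilinear and _fill changes below):

# inputs[1] is a relay expression that computes a shape/size tensor at graph
# construction time; params maps parameter names to concrete values.
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
attr['size'] = size    # the inferred size can now be stored as a static attribute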

def _get_relay_op(op_name):
try:
op = getattr(_op, op_name)
@@ -465,7 +479,12 @@ def _impl(inputs, attr, params):

def _resize_bilinear():
def _impl(inputs, attr, params):
attr['size'] = attr['_output_shapes'][0][1:3]
size = attr['_output_shapes'][0][1:3]
        # The size must be fully defined. If any axis is unknown (-1), infer the
        # concrete size from the size tensor input.
if -1 in size:
size = _infer_value(inputs[1], params).asnumpy().reshape([-1]).tolist()
attr['size'] = size
inputs.pop(1)
# NHWC
attr['layout'] = 'NHWC'
@@ -574,25 +593,71 @@ def _impl(inputs, attr, params):
except AttributeError:
# Shape operator is already pruned, hence
# try to infer shape by precompute prune if possible.
func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.context("llvm", 0)
from tvm.contrib import graph_runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
m.run()
params_new = m.get_output(0)
params_new = _infer_value(inputs[1], params)
inputs.pop(1)
return AttrCvt(
op_name="reshape",
extras={'newshape':tuple(params_new.asnumpy().astype('int64').flatten())},
ignores=['Tshape'])(inputs, attr)
return _impl


def _depth_to_space():
def _impl(inputs, attr, params):
# Need to handle data layouts differently.
input_shape = attr['_input_shapes'][inputs[0]]
block_size = int(attr['block_size'])
if attr['data_format'].decode("utf-8") == 'NHWC':
in_n, in_h, in_w, in_c = input_shape
new_c = int(in_c / (block_size * block_size))

# First expand input to larger dimension.
expanded = _op.reshape(
inputs[0], newshape=(in_n, in_h, in_w, block_size, block_size, new_c))
# Now reorder to expand spatial blocks.
transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5))
# Finally reshape to proper output.
new_h = in_h * block_size
new_w = in_w * block_size
newshape = (in_n, new_h, new_w, new_c)

else: # Handle NCHW layout
in_n, in_c, in_h, in_w = input_shape
new_c = int(in_c / (block_size * block_size))

expanded = _op.reshape(
inputs[0], newshape=(in_n, block_size, block_size, new_c, in_h, in_w))
transposed = _op.transpose(expanded, axes=(0, 3, 4, 1, 5, 2))
new_h = in_h * block_size
new_w = in_w * block_size
newshape = (in_n, new_c, new_h, new_w)

return AttrCvt(
op_name="reshape",
extras={'newshape': newshape},
ignores=['data_format', 'block_size'])([transposed], attr)

return _impl
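
As a sanity check on the NHWC branch, the reshape/transpose/reshape sequence above reproduces tf.nn.depth_to_space; an illustrative stand-alone NumPy version:

import numpy as np

n, h, w, c, bs = 1, 2, 2, 8, 2
x = np.arange(n * h * w * c).reshape(n, h, w, c)
new_c = c // (bs * bs)
expanded = x.reshape(n, h, w, bs, bs, new_c)          # split channels into (bs, bs, new_c)
transposed = expanded.transpose(0, 1, 3, 2, 4, 5)     # interleave spatial blocks
out = transposed.reshape(n, h * bs, w * bs, new_c)    # collapse into enlarged H and W
assert out.shape == (1, 4, 4, 2)
# Spot-check one element against the depth_to_space definition:
assert out[0, 1, 1, 0] == x[0, 0, 0, (1 * bs + 1) * new_c + 0]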


def _bias_add():
def _impl(inputs, attr, params):
return _op.add(inputs[0], inputs[1])
# Must expand for proper broadcasting in NCHW.
if attr['data_format'].decode("utf-8") == 'NCHW':
bias = _op.reshape(inputs[1], newshape=(1, -1, 1, 1))
else:
bias = inputs[1]
return _op.add(inputs[0], bias)
return _impl
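
The reshape above is needed because broadcasting aligns trailing axes, so a (C,)-shaped bias would line up with W rather than C in NCHW. A small NumPy illustration:

import numpy as np

x = np.zeros((2, 8, 4, 4))   # NCHW activations
bias = np.ones(8)            # per-channel bias, shape (C,)
# x + bias would try to broadcast bias against the trailing W axis (and fail here,
# since W != C); reshaping to (1, C, 1, 1) aligns it with the channel axis instead.
y = x + bias.reshape(1, -1, 1, 1)
assert y.shape == (2, 8, 4, 4)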

def _broadcast_to():
def _impl(inputs, attr, params):
if isinstance(inputs[1], _expr.Var):
shape = params[inputs[1].name_hint]
else:
shape = _infer_value(inputs[1], params)
shape = list(shape.asnumpy().reshape([-1]))
return _op.broadcast_to(inputs[0], shape)
return _impl

def _squeeze():
@@ -666,9 +731,15 @@ def _impl(inputs, attr, params):

def _fill():
def _impl(inputs, attr, params):
output_shape = attr['_output_shapes'][0]
        # The output shape must be fully defined. If any axis is unknown (-1), we
        # must infer its value from the shape input.
if -1 in output_shape:
output_shape = _infer_value(inputs[0], params).asnumpy().reshape([-1]).tolist()

fill_arg = params.pop(inputs.pop(1).name_hint)
return _op.full(tvm.relay.const(fill_arg.asnumpy()[0], attr['T'].name),
attr['_output_shapes'][0], attr['T'].name)
output_shape, attr['T'].name)
return _impl

def _lrn():
@@ -1115,6 +1186,7 @@ def _impl(inputs, attr, params):
'BatchNormWithGlobalNormalization' : _batch_norm(),
'BatchToSpaceND' : _batch_to_space_nd(),
'BiasAdd' : _bias_add(),
'BroadcastTo' : _broadcast_to(),
'Cast' : _cast(),
'Ceil' : AttrCvt('ceil'),
'CheckNumerics' : _check_numerics(),
@@ -1123,6 +1195,7 @@ def _impl(inputs, attr, params):
'Conv2D' : _conv('conv'),
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'DepthToSpace' : _depth_to_space(),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Exp' : AttrCvt('exp'),
@@ -1158,11 +1231,12 @@ def _impl(inputs, attr, params):
'Prod' : _prod(),
'Range' : _range(),
'Rank' : _rank(),
'RealDiv' : _elemwise('div'),
'RealDiv' : _elemwise('divide'),
'Relu' : AttrCvt('relu'),
'Relu6' : _relu6(),
'Reshape' : _reshape(),
'ResizeBilinear' : _resize_bilinear(),
'ResizeBicubic' : _resize_bilinear(),
'ReverseV2' : _reverse_v2(),
'Round' : AttrCvt('round'),
'Rsqrt' : _rsqrt(),
131 changes: 123 additions & 8 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -47,7 +47,8 @@ def convert_to_list(x):
x = [x]
return x

def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm', out_names=None):
def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
@@ -71,7 +72,7 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
layout=layout,
shape=shape_dict,
outputs=out_names)
with relay.build_config(opt_level=3):
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build(sym, target, params=params)

ctx = tvm.context(target, 0)
@@ -85,8 +86,8 @@ def run_tvm_graph(graph_def, input_data, input_node, num_output=1, target='llvm'
# execute
m.run()
# get outputs
assert out_names is None or num_output == len(out_names),"out_names: {} num_output: {}".format(
out_names, num_output)
assert out_names is None or num_output == len(out_names), (
"out_names: {} num_output: {}".format(out_names, num_output))
tvm_output_list = []
for i in range(0, num_output):
tvm_output = m.get_output(i)
@@ -111,7 +112,8 @@ def run_tf_graph(sess, input_data, input_node, output_node):
return output_data


def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False, no_gpu=False):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3):
"""Generic function to generate and compare tensorflow and TVM output"""

out_name = convert_to_list(out_name)
@@ -142,8 +144,9 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
if no_gpu and device == 'cuda':
continue

tvm_output = run_tvm_graph(final_graph_def, in_data, in_node, target=device,
out_names=out_name, num_output=len(out_name))
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
target=device, out_names=out_name,
num_output=len(out_name), opt_level=opt_level)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
@@ -411,6 +414,23 @@ def test_forward_reshape():
_test_reshape(np.arange(6), [-1])

#######################################################################
# DepthToSpace
# ------------

def _test_depthtospace(data, block_size):
""" One iteration of depth_to_space operation with given data and block size """

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
array_ops.depth_to_space(in_data, block_size)

compare_tf_with_tvm(data, 'Placeholder:0', 'DepthToSpace:0')

def test_forward_depthtospace():
_test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
_test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)


#######################################################################
# Squeeze
# -------
@@ -840,16 +860,108 @@ def _test_resize_bilinear(in_shape, to_shape, align_corners):

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
shape_data = constant_op.constant(
shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.image.resize_bilinear(in_data, shape_data, align_corners=align_corners)

compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')

def _test_resize_bilinear_from_tensor(in_shape, align_corners):
""" One iteration of resize bilinear with non-constant output shape, requires
value inference to get proper output shape."""

data = np.random.uniform(size=in_shape).astype('float32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(
shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)
to_shape = tf.shape(in_data)[2:]
tf.image.resize_bilinear(in_data, to_shape, align_corners=align_corners)

compare_tf_with_tvm(data, 'Placeholder:0', 'ResizeBilinear:0')

def test_forward_resize_bilinear():
""" Resize Bilinear """

_test_resize_bilinear((4, 16, 32, 32), [50, 50], False)
_test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
_test_resize_bilinear_from_tensor((4, 16, 32, 32), False)
_test_resize_bilinear_from_tensor((6, 32, 50, 50), True)

#######################################################################
# BroadcastTo
# -----------

def _test_broadcast_to(in_shape, to_shape):
""" One iteration of broadcast_to"""

data = np.random.uniform(size=in_shape).astype('float32')
shape_data = np.array(to_shape).astype('int32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
shape_data = constant_op.constant(
shape_data, shape=shape_data.shape, dtype=shape_data.dtype)
tf.broadcast_to(in_data, shape_data)

compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0', opt_level=0)


def _test_broadcast_to_from_tensor(in_shape):
""" One iteration of broadcast_to with unknown shape at graph build"""

data = np.random.uniform(size=in_shape).astype('float32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(
shape=[None], dtype=data.dtype)

shape_data = tf.multiply(tf.shape(in_data), 32)
tf.broadcast_to(in_data, shape_data)

compare_tf_with_tvm(data, 'Placeholder:0', 'BroadcastTo:0')


def test_forward_broadcast_to():
""" Resize Bilinear """

_test_broadcast_to((4, 1, 32, 32), [4, 8, 32, 32])
_test_broadcast_to((6, 32, 32, 1), [6, 32, 32, 16])
_test_broadcast_to_from_tensor((1))


#######################################################################
# Fill
# ----

def _test_fill(in_shape):
""" Use the fill op to create a tensor of ones with non-constant shape."""

with tf.Graph().as_default():
tf.ones(shape=in_shape, dtype='float32')
compare_tf_with_tvm(in_shape, [], 'ones:0', opt_level=1)

def _test_fill_from_tensor(in_shape):
""" Use the fill op to create a tensor of ones with non-constant shape.
Some extra ops need to be added here to prevent the graph from
being fully constant and folded away."""

data = np.random.uniform(size=in_shape).astype('float32')

with tf.Graph().as_default():
in_data = array_ops.placeholder(
shape=[in_shape[0], in_shape[1], None, None], dtype=data.dtype)

x = tf.ones(shape=2*tf.shape(in_data), dtype=data.dtype)
y = tf.math.add(in_data, tf.reduce_mean(x), name='out1')
compare_tf_with_tvm(data, 'Placeholder:0', 'out1:0')

def test_forward_fill():
""" Resize Bilinear """

_test_fill((32))
_test_fill((6, 32, 64, 64))
_test_fill_from_tensor((6, 32, 64, 64))

#######################################################################
# Crop to bounding box
@@ -1549,9 +1661,12 @@ def test_forward_reduce_prod():
# Transforms
test_forward_transpose()
test_forward_reshape()
test_forward_depthtospace()
test_forward_squeeze()
test_forward_pack()
test_forward_resize_bilinear()
test_forward_broadcast_to()
test_forward_fill()
test_forward_crop()
test_forward_pad()
test_forward_gather()