Skip to content

Commit 1755c26

Browse files
yongwwwwweic
authored andcommitted
[Relay][Frontend] Support tf.where (apache#2936)
* [Relay][Frontend] Support tf.where * fix comments
1 parent fdb31f5 commit 1755c26

File tree

2 files changed

+39
-18
lines changed

2 files changed

+39
-18
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,10 @@ def _impl(inputs, attr, params):
683683
new_input = []
684684
new_input.append(inputs.pop(0))
685685
new_input.append(inputs.pop(0))
686-
return AttrCvt(op_name="take",
687-
extras={'axis': tvm.const(axis, 'int32')},
688-
ignores=['Tindices', 'Tparams', 'validate_indices', \
689-
'Taxis', '_class'])(new_input, attr)
686+
return AttrCvt(op_name="take",
687+
extras={'axis': tvm.const(axis, 'int32')},
688+
ignores=['Tindices', 'Tparams', 'validate_indices', \
689+
'Taxis', '_class'])(new_input, attr)
690690
return _impl
691691

692692
def _infer_out_shapes(inputs, params):
@@ -818,7 +818,6 @@ def _impl(inputs, attr, params):
818818
ignores=['Tpaddings'],)(new_inputs, attr)
819819
return _impl
820820

821-
822821
def _transpose():
823822
def _impl(inputs, attr, params):
824823
# If perm is not specified, axes is left empty,
@@ -831,6 +830,11 @@ def _impl(inputs, attr, params):
831830
return _op.transpose(inputs[0], axes=axes)
832831
return _impl
833832

833+
def _where():
834+
def _impl(inputs, attr, params):
835+
return AttrCvt(op_name="where")(inputs, attr)
836+
return _impl
837+
834838
def _rank():
835839
def _impl(inputs, attr, params):
836840
input_shape = attr['_input_shapes'][inputs[0]]
@@ -1015,6 +1019,7 @@ def _impl(inputs, attr, params):
10151019
'DepthwiseConv2dNative' : _conv('depthwise'),
10161020
'Shape' : _shape(),
10171021
'Sigmoid' : AttrCvt('sigmoid'),
1022+
'Select' : _where(),
10181023
'Fill' : _fill(),
10191024
'GatherV2' : _gather(),
10201025
'Gather' : _gather(),

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
108108
in_node = [0]*len(in_name)
109109
for i in range(len(in_name)):
110110
in_node[i] = in_name[i].split(':')[0] if ":" in in_name[i] else in_name[i]
111-
112111
with tf.Session() as sess:
113112
if init_global_variables:
114113
sess.run(variables.global_variables_initializer())
@@ -483,7 +482,7 @@ def _test_gather(ip_shape, indice_shape, indice_value, axis, dtype):
483482
in_data = tf.placeholder(dtype, ip_shape, name="in_data")
484483
indices = tf.placeholder("int32", indice_shape, name="indices")
485484
tf.gather(in_data, indices, axis=axis)
486-
np_data = np.random.uniform(size=ip_shape).astype(dtype)
485+
np_data = np.random.uniform(1, 10, size=ip_shape).astype(dtype)
487486

488487
def _fill_indices(indice_value):
489488
indices = np.array(ip_shape, dtype=dtype)
@@ -500,14 +499,14 @@ def test_forward_gather():
500499
'''test GatherV2 layer'''
501500
_test_gather((4,), (1,), 1, 0, 'int32')
502501
_test_gather((4,), (1,), 1, 0, 'float32')
503-
_test_gather((1,4), (1,), [0], 0, 'int32')
504-
_test_gather((4,), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
505-
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'int32')
506-
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 1, 'int32')
507-
_test_gather((2,2), (1,2,2), [[[1,0],[0,1]]], 0, 'float32')
508-
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 0, 'int32')
509-
_test_gather((3,3,3), (1,1,2), [[[1,0]]], 2, 'int32')
510-
_test_gather((4,3,5,6), (1,4), [[2,1,0,0]], 0, 'float32')
502+
_test_gather((1, 4), (1,), [0], 0, 'int32')
503+
_test_gather((4,), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
504+
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'int32')
505+
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 1, 'int32')
506+
_test_gather((2, 2), (1, 2, 2), [[[1, 0],[0, 1]]], 0, 'float32')
507+
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 0, 'int32')
508+
_test_gather((3, 3, 3), (1, 1, 2), [[[1, 0]]], 2, 'int32')
509+
_test_gather((4, 3, 5, 6), (1, 4), [[2, 1, 0, 0]], 0, 'float32')
511510

512511

513512
def _test_gather_v1(ip_shape, indice_shape, indice_value, dtype):
@@ -620,10 +619,10 @@ def _test_unstack(ip_shape, axis, dtype):
620619
def test_forward_unstack():
621620
'''test unstack layer'''
622621
_test_unstack((6,), 0, 'int32')
623-
_test_unstack((2,6), 1, 'float64')
622+
_test_unstack((2, 6), 1, 'float64')
624623
# negative axis
625-
_test_unstack((1,4), -1, 'int32')
626-
_test_unstack((3,6,4), -2, 'float32')
624+
_test_unstack((1, 4), -1, 'int32')
625+
_test_unstack((3, 6, 4), -2, 'float32')
627626

628627

629628
#######################################################################
@@ -863,6 +862,22 @@ def test_forward_logical():
863862
test_logical_not()
864863

865864

865+
#######################################################################
866+
# Where, Select
867+
# -------------
868+
def test_where():
869+
''' Where: return elements depending on conditions'''
870+
with tf.Graph().as_default():
871+
with tf.Session() as sess:
872+
input1 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input1')
873+
input2 = tf.placeholder(tf.int32, shape=[1, 4, 4, 3], name='input2')
874+
mask = input1 > input2
875+
tf.where(mask, input1 + 1, input2 * 2)
876+
in_data1 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
877+
in_data2 = np.random.uniform(0, 10, size=(1, 4, 4, 3)).astype("uint32")
878+
compare_tf_with_tvm([in_data1, in_data2], ['input1:0', 'input2:0'], 'Select:0')
879+
880+
866881
#######################################################################
867882
# Inception V3
868883
# ------------
@@ -1299,3 +1314,4 @@ def test_forward_rel_ops():
12991314
# Relational ops
13001315
test_forward_rel_ops()
13011316
test_forward_logical()
1317+
test_where()

0 commit comments

Comments
 (0)