Skip to content

Commit 602fd9b

Browse files
anijain2305Ubuntu
authored andcommitted
TF argmax - handling int64 datatype (apache#6674)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-202.us-west-2.compute.internal>
1 parent 8eff7f2 commit 602fd9b

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def _impl(inputs, attr, params, mod):
146146
raise TypeError(
147147
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
148148
)
149-
return func(inputs[0], axis=axis_input_value, keepdims=False)
149+
out = func(inputs[0], axis=axis_input_value, keepdims=False)
150+
dtype = attr["output_type"].name
151+
if dtype != "int32":
152+
out = _op.cast(out, dtype=dtype)
153+
return out
150154

151155
return _impl
152156

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,16 +1616,16 @@ def _test_argx(func, data, **kwargs):
16161616

16171617
with tf.Graph().as_default():
16181618
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
1619-
func(inp, name="argx0", output_type=tf.int32, **kwargs)
1620-
1619+
func(inp, name="argx0", **kwargs)
16211620
compare_tf_with_tvm(data, "c0:0", "argx0:0")
16221621

16231622

16241623
def test_forward_argminmax():
1625-
for axis in [None, 0, 1, 2]:
1626-
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
1627-
_test_argx(tf.argmax, data=data, axis=axis)
1628-
_test_argx(tf.argmin, data=data, axis=axis)
1624+
for output_type in [tf.int64, tf.int32]:
1625+
for axis in [None, 0, 1, 2]:
1626+
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
1627+
_test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
1628+
_test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)
16291629

16301630

16311631
#######################################################################

0 commit comments

Comments
 (0)