Skip to content

Commit d24634a

Browse files
anijain2305Ubuntu
andauthored
TF argmax - handling int64 datatype (#6674)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-202.us-west-2.compute.internal>
1 parent d5728bd commit d24634a

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,11 @@ def _impl(inputs, attr, params, mod):
146146
raise TypeError(
147147
"Unsupported argument for `{}` : `axis` should be a constant".format(func_name)
148148
)
149-
return func(inputs[0], axis=axis_input_value, keepdims=False)
149+
out = func(inputs[0], axis=axis_input_value, keepdims=False)
150+
dtype = attr["output_type"].name
151+
if dtype != "int32":
152+
out = _op.cast(out, dtype=dtype)
153+
return out
150154

151155
return _impl
152156

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,16 +1601,16 @@ def _test_argx(func, data, **kwargs):
16011601

16021602
with tf.Graph().as_default():
16031603
inp = array_ops.placeholder(shape=data.shape, dtype=data.dtype, name="c0")
1604-
func(inp, name="argx0", output_type=tf.int32, **kwargs)
1605-
1604+
func(inp, name="argx0", **kwargs)
16061605
compare_tf_with_tvm(data, "c0:0", "argx0:0")
16071606

16081607

16091608
def test_forward_argminmax():
1610-
for axis in [None, 0, 1, 2]:
1611-
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
1612-
_test_argx(tf.argmax, data=data, axis=axis)
1613-
_test_argx(tf.argmin, data=data, axis=axis)
1609+
for output_type in [tf.int64, tf.int32]:
1610+
for axis in [None, 0, 1, 2]:
1611+
data = np.random.uniform(size=(8, 4, 9)).astype("float32")
1612+
_test_argx(tf.argmax, data=data, axis=axis, output_type=output_type)
1613+
_test_argx(tf.argmin, data=data, axis=axis, output_type=output_type)
16141614

16151615

16161616
#######################################################################

0 commit comments

Comments
 (0)