Skip to content

Commit a6e2417

Browse files
authored
[TF parser] Handle int64 dtype in range (#6918)
1 parent f9d26fb commit a6e2417

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,9 +1454,9 @@ def _impl(inputs, attr, params, mod):
14541454
break
14551455

14561456
if is_symbolic_shape:
1457-
ret = _op.shape_of(inputs[0], dtype="int32")
1457+
ret = _op.shape_of(inputs[0], dtype=attr["out_type"].name)
14581458
else:
1459-
ret = np.array(input_shape, dtype="int32")
1459+
ret = np.array(input_shape, dtype=attr["out_type"].name)
14601460
return ret
14611461

14621462
return _impl
@@ -1862,11 +1862,11 @@ def _impl(inputs, attr, params, mod):
18621862

18631863
dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
18641864
if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
1865-
start = _expr.const(start)
1865+
start = _expr.const(start, dtype=dtype)
18661866
if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)):
1867-
limit = _expr.const(limit)
1867+
limit = _expr.const(limit, dtype=dtype)
18681868
if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)):
1869-
delta = _expr.const(delta)
1869+
delta = _expr.const(delta, dtype=dtype)
18701870

18711871
return AttrCvt(
18721872
op_name="arange",

tests/python/frontend/tensorflow/test_forward.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2768,10 +2768,11 @@ def test_forward_unpack():
27682768

27692769
def test_forward_range():
27702770
"""test operator Range"""
2771-
tf.reset_default_graph()
2772-
with tf.Graph().as_default():
2773-
tf.range(1, 18, 3, name="range")
2774-
compare_tf_with_tvm([], [], "range:0")
2771+
for dtype in [tf.int32, tf.int64]:
2772+
tf.reset_default_graph()
2773+
with tf.Graph().as_default():
2774+
tf.range(1, 18, 3, name="range", dtype=dtype)
2775+
compare_tf_with_tvm([], [], "range:0")
27752776

27762777
"""test type assignment for operator Range"""
27772778
tf.reset_default_graph()

0 commit comments

Comments
 (0)