Skip to content

Commit 0346241

Browse files
_get_arange_length must see parameters with canonical types
Example now giving expected output: ``` assert dpt.arange(np.float32(0), stop=np.float32(504.0), step=np.float32(100), dtype=np.float32).shape == (5,) ``` Test added.
1 parent 83c7578 commit 0346241

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

dpctl/tensor/_ctors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def arange(
553553
allow_bool=False,
554554
)
555555
try:
556-
tmp = _get_arange_length(start, stop, step)
556+
tmp = _get_arange_length(start_, stop_, step_)
557557
sh = int(tmp)
558558
if sh < 0:
559559
sh = 0

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,9 @@ def test_arange_mixed_types():
10561056
assert x.shape[0] == 2
10571057
assert int(x[1]) == 100 + int(x[0])
10581058

1059+
x = dpt.arange(0, stop=np.float32(504), step=100, dtype="f4")
1060+
assert x.shape == (6,)
1061+
10591062

10601063
@pytest.mark.parametrize(
10611064
"dt",

0 commit comments

Comments
 (0)