Skip to content

Commit 1131ce9

Browse files
Merge pull request #855 from IntelPython/bugfix-arange
Fixed bug causing dpt.arange(0,1,1e-3,dtype="f4") be length 1
2 parents 42afc0c + b1e7dcd commit 1131ce9

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

dpctl/tensor/_ctors.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -471,22 +471,24 @@ def _coerce_and_infer_dt(*args, dt):
471471
raise ValueError(f"Data type {dt} is not supported")
472472

473473

474+
def _round_for_arange(tmp):
475+
k = int(tmp)
476+
if k > 0 and float(k) < tmp:
477+
tmp = tmp + 1
478+
return tmp
479+
480+
474481
def _get_arange_length(start, stop, step):
475482
"Compute length of arange sequence"
476483
span = stop - start
477484
if type(step) in [int, float] and type(span) in [int, float]:
478-
offset = -1 if step > 0 else 1
479-
tmp = 1 + (span + offset) / step
480-
return tmp
485+
return _round_for_arange(span / step)
481486
tmp = span / step
482487
if type(tmp) is complex and tmp.imag == 0:
483488
tmp = tmp.real
484489
else:
485490
return tmp
486-
k = int(tmp)
487-
if k > 0 and float(k) < tmp:
488-
tmp = tmp + 1
489-
return tmp
491+
return _round_for_arange(tmp)
490492

491493

492494
def arange(

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,20 @@ def test_arange(dt):
994994
assert X2.shape == (sz,)
995995

996996

997+
def test_arange_fp():
998+
try:
999+
q = dpctl.SyclQueue()
1000+
except dpctl.SyclQueueCreationError:
1001+
pytest.skip("Queue could not be created")
1002+
1003+
assert dpt.arange(7, 0, -2, dtype="f4", device=q).shape == (4,)
1004+
assert dpt.arange(0, 1, 0.25, dtype="f4", device=q).shape == (4,)
1005+
1006+
if q.sycl_device.has_aspect_fp64:
1007+
assert dpt.arange(7, 0, -2, dtype="f8", device=q).shape == (4,)
1008+
assert dpt.arange(0, 1, 0.25, dtype="f4", device=q).shape == (4,)
1009+
1010+
9971011
@pytest.mark.parametrize(
9981012
"dt",
9991013
_all_dtypes,

0 commit comments

Comments
 (0)