Skip to content

Commit e86a6d6

Browse files
Fixed length computations, added test
``` In [4]: import numpy In [5]: numpy.arange(-5, stop=10**5, step=2.7, dtype=numpy.int64).shape Out[5]: (37039,) In [6]: import dpctl.tensor as dpt In [7]: dpt.arange(-5, stop=10**5, step=2.7, dtype=numpy.int64).shape Out[7]: (50003,) ``` Now lengths are consistent.
1 parent 0346241 commit e86a6d6

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

dpctl/tensor/_ctors.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,13 +491,14 @@ def _round_for_arange(tmp):
491491
def _get_arange_length(start, stop, step):
492492
"Compute length of arange sequence"
493493
span = stop - start
494-
if type(step) in [int, float] and type(span) in [int, float]:
494+
if hasattr(step, "__float__") and hasattr(span, "__float__"):
495495
return _round_for_arange(span / step)
496496
tmp = span / step
497-
if type(tmp) is complex and tmp.imag == 0:
497+
if hasattr(tmp, "__complex__"):
498+
tmp = complex(tmp)
498499
tmp = tmp.real
499500
else:
500-
return tmp
501+
tmp = float(tmp)
501502
return _round_for_arange(tmp)
502503

503504

@@ -553,7 +554,7 @@ def arange(
553554
allow_bool=False,
554555
)
555556
try:
556-
tmp = _get_arange_length(start_, stop_, step_)
557+
tmp = _get_arange_length(start, stop, step)
557558
sh = int(tmp)
558559
if sh < 0:
559560
sh = 0

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,13 +1052,18 @@ def test_arange_mixed_types():
10521052
assert x.shape[0] == 3
10531053
assert int(x[1]) == 99 + int(x[0])
10541054

1055-
x = dpt.arange(+2.5, stop=200, step=100, dtype="int32", sycl_queue=q)
1055+
x = dpt.arange(+2.5, stop=200, step=100, dtype="int32", device=x.device)
10561056
assert x.shape[0] == 2
10571057
assert int(x[1]) == 100 + int(x[0])
10581058

1059-
x = dpt.arange(0, stop=np.float32(504), step=100, dtype="f4")
1059+
_stop = np.float32(504)
1060+
x = dpt.arange(0, stop=_stop, step=100, dtype="f4", device=x.device)
10601061
assert x.shape == (6,)
10611062

1063+
# ensure length is determined using uncast parameters
1064+
x = dpt.arange(-5, stop=10**2, step=2.7, dtype="int64", device=x.device)
1065+
assert x.shape == (39,)
1066+
10621067

10631068
@pytest.mark.parametrize(
10641069
"dt",

0 commit comments

Comments
 (0)