Skip to content

Commit 2723947

Browse files
Merge pull request #945 from IntelPython/fix-arange-issues
Fix arange issues
2 parents 59980a2 + 2e78bbb commit 2723947

File tree

4 files changed

+171
-175
lines changed

4 files changed

+171
-175
lines changed

dpctl/tensor/_ctors.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,21 +483,22 @@ def _coerce_and_infer_dt(*args, dt, sycl_queue, err_msg, allow_bool=False):
483483

484484
def _round_for_arange(tmp):
485485
k = int(tmp)
486-
if k > 0 and float(k) < tmp:
486+
if k >= 0 and float(k) < tmp:
487487
tmp = tmp + 1
488488
return tmp
489489

490490

491491
def _get_arange_length(start, stop, step):
492492
"Compute length of arange sequence"
493493
span = stop - start
494-
if type(step) in [int, float] and type(span) in [int, float]:
494+
if hasattr(step, "__float__") and hasattr(span, "__float__"):
495495
return _round_for_arange(span / step)
496496
tmp = span / step
497-
if type(tmp) is complex and tmp.imag == 0:
497+
if hasattr(tmp, "__complex__"):
498+
tmp = complex(tmp)
498499
tmp = tmp.real
499500
else:
500-
return tmp
501+
tmp = float(tmp)
501502
return _round_for_arange(tmp)
502503

503504

@@ -536,13 +537,18 @@ def arange(
536537
if stop is None:
537538
stop = start
538539
start = 0
540+
if step is None:
541+
step = 1
539542
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
540543
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
541-
(start, stop, step,), dt = _coerce_and_infer_dt(
544+
is_bool = False
545+
if dtype:
546+
is_bool = (dtype is bool) or (dpt.dtype(dtype) == dpt.bool)
547+
(start_, stop_, step_), dt = _coerce_and_infer_dt(
542548
start,
543549
stop,
544550
step,
545-
dt=dtype,
551+
dt=dpt.int8 if is_bool else dtype,
546552
sycl_queue=sycl_queue,
547553
err_msg="start, stop, and step must be Python scalars",
548554
allow_bool=False,
@@ -554,18 +560,40 @@ def arange(
554560
sh = 0
555561
except TypeError:
556562
sh = 0
563+
if is_bool and sh > 2:
564+
raise ValueError("no fill-function for boolean data type")
557565
res = dpt.usm_ndarray(
558566
(sh,),
559567
dtype=dt,
560568
buffer=usm_type,
561569
order="C",
562570
buffer_ctor_kwargs={"queue": sycl_queue},
563571
)
564-
_step = (start + step) - start
565-
_step = dt.type(_step)
566-
_start = dt.type(start)
572+
sc_ty = dt.type
573+
_first = sc_ty(start)
574+
if sh > 1:
575+
_second = sc_ty(start + step)
576+
if dt in [dpt.uint8, dpt.uint16, dpt.uint32, dpt.uint64]:
577+
int64_ty = dpt.int64.type
578+
_step = int64_ty(_second) - int64_ty(_first)
579+
else:
580+
_step = _second - _first
581+
_step = sc_ty(_step)
582+
else:
583+
_step = sc_ty(1)
584+
_start = _first
567585
hev, _ = ti._linspace_step(_start, _step, res, sycl_queue)
568586
hev.wait()
587+
if is_bool:
588+
res_out = dpt.usm_ndarray(
589+
(sh,),
590+
dtype=dpt.bool,
591+
buffer=usm_type,
592+
order="C",
593+
buffer_ctor_kwargs={"queue": sycl_queue},
594+
)
595+
res_out[:] = res
596+
res = res_out
569597
return res
570598

571599

dpctl/tests/helper/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@
1919

2020
from ._helper import (
2121
create_invalid_capsule,
22+
get_queue_or_skip,
2223
has_cpu,
2324
has_gpu,
2425
has_sycl_platforms,
26+
skip_if_dtype_not_supported,
2527
)
2628

2729
__all__ = [
2830
"create_invalid_capsule",
2931
"has_cpu",
3032
"has_gpu",
3133
"has_sycl_platforms",
34+
"get_queue_or_skip",
35+
"skip_if_dtype_not_supported",
3236
]

dpctl/tests/helper/_helper.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import pytest
18+
1719
import dpctl
1820

1921

@@ -39,3 +41,38 @@ def create_invalid_capsule():
3941
ctor.restype = ctypes.py_object
4042
ctor.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
4143
return ctor(id(ctor), b"invalid", 0)
44+
45+
46+
def get_queue_or_skip(args=tuple()):
47+
try:
48+
q = dpctl.SyclQueue(*args)
49+
except dpctl.SyclQueueCreationError:
50+
pytest.skip(f"Queue could not be created from {args}")
51+
return q
52+
53+
54+
def skip_if_dtype_not_supported(dt, q_or_dev):
55+
import dpctl.tensor as dpt
56+
57+
dt = dpt.dtype(dt)
58+
if type(q_or_dev) is dpctl.SyclQueue:
59+
dev = q_or_dev.sycl_device
60+
elif type(q_or_dev) is dpctl.SyclDevice:
61+
dev = q_or_dev
62+
else:
63+
raise TypeError(
64+
"Expected dpctl.SyclQueue or dpctl.SyclDevice, "
65+
f"got {type(q_or_dev)}"
66+
)
67+
dev_has_dp = dev.has_aspect_fp64
68+
if dev_has_dp is False and dt in [dpt.float64, dpt.complex128]:
69+
pytest.skip(
70+
f"{dev.name} does not support double precision floating point types"
71+
)
72+
dev_has_hp = dev.has_aspect_fp16
73+
if dev_has_hp is False and dt in [
74+
dpt.float16,
75+
]:
76+
pytest.skip(
77+
f"{dev.name} does not support half precision floating point type"
78+
)

0 commit comments

Comments
 (0)