Skip to content

Commit 872f372

Browse files
Added tests to increase coverage of elementwise functions
1 parent a947b3a commit 872f372

File tree

2 files changed

+230
-72
lines changed

2 files changed

+230
-72
lines changed

dpctl/tensor/_elementwise_common.py

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,7 @@ def __init__(self, name, result_type_resolver_fn, unary_dp_impl_fn, docs):
4949

5050
def __call__(self, x, order="K"):
5151
if not isinstance(x, dpt.usm_ndarray):
52-
raise TypeError(
53-
f"Expected :class:`dpctl.tensor.usm_ndarray`, got {type(x)}"
54-
)
52+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
5553
if order not in ["C", "F", "K", "A"]:
5654
order = "K"
5755
buf_dt, res_dt = _find_buf_dtype(
@@ -85,8 +83,6 @@ def __call__(self, x, order="K"):
8583
if order == "K":
8684
r = _empty_like_orderK(buf, res_dt)
8785
else:
88-
if order == "A":
89-
order = "F" if buf.flags.f_contiguous else "C"
9086
r = dpt.empty_like(buf, dtype=res_dt, order=order)
9187

9288
ht, _ = self.unary_fn_(buf, r, sycl_queue=exec_q, depends=[copy_ev])
@@ -142,6 +138,8 @@ def get(self):
142138
def _get_dtype(o, dev):
143139
if isinstance(o, dpt.usm_ndarray):
144140
return o.dtype
141+
if hasattr(o, "__sycl_usm_array_interface__"):
142+
return dpt.asarray(o).dtype
145143
if _is_buffer(o):
146144
host_dt = np.array(o).dtype
147145
dev_dt = _to_device_supported_dtype(host_dt, dev)
@@ -224,13 +222,12 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
224222
return dpt.bool, o2_dtype
225223
if isinstance(o1_dtype, WeakIntegralType):
226224
return dpt.int64, o2_dtype
227-
if isinstance(o1_dtype, WeakInexactType):
228-
if isinstance(o1_dtype.get(), complex):
229-
return (
230-
_to_device_supported_dtype(dpt.complex128, dev),
231-
o2_dtype,
232-
)
233-
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
225+
if isinstance(o1_dtype.get(), complex):
226+
return (
227+
_to_device_supported_dtype(dpt.complex128, dev),
228+
o2_dtype,
229+
)
230+
return _to_device_supported_dtype(dpt.float64, dev), o2_dtype
234231
else:
235232
return o2_dtype, o2_dtype
236233
elif isinstance(
@@ -243,15 +240,12 @@ def _resolve_weak_types(o1_dtype, o2_dtype, dev):
243240
return o1_dtype, dpt.bool
244241
if isinstance(o2_dtype, WeakIntegralType):
245242
return o1_dtype, dpt.int64
246-
if isinstance(o2_dtype, WeakInexactType):
247-
if isinstance(o2_dtype.get(), complex):
248-
return o1_dtype, _to_device_supported_dtype(
249-
dpt.complex128, dev
250-
)
251-
return (
252-
o1_dtype,
253-
_to_device_supported_dtype(dpt.float64, dev),
254-
)
243+
if isinstance(o2_dtype.get(), complex):
244+
return o1_dtype, _to_device_supported_dtype(dpt.complex128, dev)
245+
return (
246+
o1_dtype,
247+
_to_device_supported_dtype(dpt.float64, dev),
248+
)
255249
else:
256250
return o1_dtype, o1_dtype
257251
else:
@@ -287,10 +281,14 @@ def __repr__(self):
287281
return f"<BinaryElementwiseFunc '{self.name_}'>"
288282

289283
def __call__(self, o1, o2, order="K"):
284+
if order not in ["K", "C", "F", "A"]:
285+
order = "K"
290286
q1, o1_usm_type = _get_queue_usm_type(o1)
291287
q2, o2_usm_type = _get_queue_usm_type(o2)
292288
if q1 is None and q2 is None:
293-
raise ValueError(
289+
raise ExecutionPlacementError(
290+
"Execution placement can not be unambiguously inferred "
291+
"from input arguments. "
294292
"One of the arguments must represent USM allocation and "
295293
"expose `__sycl_usm_array_interface__` property"
296294
)
@@ -415,18 +413,6 @@ def __call__(self, o1, o2, order="K"):
415413
src1, buf2, res_dt, res_usm_type, exec_q
416414
)
417415
else:
418-
if order == "A":
419-
order = (
420-
"F"
421-
if all(
422-
arr.flags.f_contiguous
423-
for arr in (
424-
src1,
425-
buf2,
426-
)
427-
)
428-
else "C"
429-
)
430416
r = dpt.empty(
431417
res_shape,
432418
dtype=res_dt,
@@ -461,18 +447,6 @@ def __call__(self, o1, o2, order="K"):
461447
buf1, src2, res_dt, res_usm_type, exec_q
462448
)
463449
else:
464-
if order == "A":
465-
order = (
466-
"F"
467-
if all(
468-
arr.flags.f_contiguous
469-
for arr in (
470-
buf1,
471-
src2,
472-
)
473-
)
474-
else "C"
475-
)
476450
r = dpt.empty(
477451
res_shape,
478452
dtype=res_dt,
@@ -493,7 +467,7 @@ def __call__(self, o1, o2, order="K"):
493467
ht_.wait()
494468
return r
495469

496-
if order in "KA":
470+
if order in ["K", "A"]:
497471
if src1.flags.f_contiguous and src2.flags.f_contiguous:
498472
order = "F"
499473
else:

dpctl/tests/test_tensor_elementwise.py

Lines changed: 209 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import ctypes
12
import itertools
23

34
import numpy as np
45
import pytest
56

67
import dpctl
78
import dpctl.tensor as dpt
9+
import dpctl.tensor._type_utils as tu
10+
import dpctl.utils
811
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
912

1013
_all_dtypes = [
@@ -26,6 +29,157 @@
2629
_usm_types = ["device", "shared", "host"]
2730

2831

32+
class MockDevice:
33+
def __init__(self, fp16: bool, fp64: bool):
34+
self.has_aspect_fp16 = fp16
35+
self.has_aspect_fp64 = fp64
36+
37+
38+
def _map_to_device_dtype(dt, dev):
39+
return tu._to_device_supported_dtype(dt, dev)
40+
41+
42+
@pytest.mark.parametrize("dtype", _all_dtypes)
43+
def test_type_utils_map_to_device_type(dtype):
44+
for fp64 in [
45+
True,
46+
False,
47+
]:
48+
for fp16 in [True, False]:
49+
dev = MockDevice(fp16, fp64)
50+
dt_in = dpt.dtype(dtype)
51+
dt_out = _map_to_device_dtype(dt_in, dev)
52+
assert isinstance(dt_out, dpt.dtype)
53+
54+
55+
def test_type_util_all_data_types():
56+
for fp64 in [
57+
True,
58+
False,
59+
]:
60+
for fp16 in [True, False]:
61+
r = tu._all_data_types(fp16, fp64)
62+
assert isinstance(r, list)
63+
# 11: bool + 4 signed + 4 unsigned inegral + float32 + complex64
64+
assert len(r) == 11 + int(fp16) + 2 * int(fp64)
65+
66+
67+
def test_type_util_can_cast():
68+
for fp64 in [
69+
True,
70+
False,
71+
]:
72+
for fp16 in [True, False]:
73+
for from_ in _all_dtypes:
74+
for to_ in _all_dtypes:
75+
r = tu._can_cast(
76+
dpt.dtype(from_), dpt.dtype(to_), fp16, fp64
77+
)
78+
assert isinstance(r, bool)
79+
80+
81+
def test_type_utils_empty_like_orderK():
82+
try:
83+
a = dpt.empty((10, 10), dtype=dpt.int32, order="F")
84+
except dpctl.SyclDeviceCreationError:
85+
pytest.skip("No SYCL devices available")
86+
X = tu._empty_like_orderK(a, dpt.int32, a.usm_type, a.device)
87+
assert X.flags["F"]
88+
89+
90+
def test_type_utils_empty_like_orderK_invalid_args():
91+
with pytest.raises(TypeError):
92+
tu._empty_like_orderK([1, 2, 3], dpt.int32, "device", None)
93+
with pytest.raises(TypeError):
94+
tu._empty_like_pair_orderK(
95+
[1, 2, 3],
96+
(
97+
1,
98+
2,
99+
3,
100+
),
101+
dpt.int32,
102+
"device",
103+
None,
104+
)
105+
try:
106+
a = dpt.empty(10, dtype=dpt.int32)
107+
except dpctl.SyclDeviceCreationError:
108+
pytest.skip("No SYCL devices available")
109+
with pytest.raises(TypeError):
110+
tu._empty_like_pair_orderK(
111+
a,
112+
(
113+
1,
114+
2,
115+
3,
116+
),
117+
dpt.int32,
118+
"device",
119+
None,
120+
)
121+
122+
123+
def test_type_utils_find_buf_dtype():
124+
def _denier_fn(dt):
125+
return False
126+
127+
for fp64 in [
128+
True,
129+
False,
130+
]:
131+
for fp16 in [True, False]:
132+
dev = MockDevice(fp16, fp64)
133+
arg_dt = dpt.float64
134+
r = tu._find_buf_dtype(arg_dt, _denier_fn, dev)
135+
assert r == (
136+
None,
137+
None,
138+
)
139+
140+
141+
def test_type_utils_find_buf_dtype2():
142+
def _denier_fn(dt1, dt2):
143+
return False
144+
145+
for fp64 in [
146+
True,
147+
False,
148+
]:
149+
for fp16 in [True, False]:
150+
dev = MockDevice(fp16, fp64)
151+
arg1_dt = dpt.float64
152+
arg2_dt = dpt.complex64
153+
r = tu._find_buf_dtype2(arg1_dt, arg2_dt, _denier_fn, dev)
154+
assert r == (
155+
None,
156+
None,
157+
None,
158+
)
159+
160+
161+
def test_unary_func_arg_validation():
162+
with pytest.raises(TypeError):
163+
dpt.abs([1, 2, 3])
164+
try:
165+
a = dpt.arange(8)
166+
except dpctl.SyclDeviceCreationError:
167+
pytest.skip("No SYCL devices available")
168+
dpt.abs(a, order="invalid")
169+
170+
171+
def test_binary_func_arg_vaidation():
172+
with pytest.raises(dpctl.utils.ExecutionPlacementError):
173+
dpt.add([1, 2, 3], 1)
174+
try:
175+
a = dpt.arange(8)
176+
except dpctl.SyclDeviceCreationError:
177+
pytest.skip("No SYCL devices available")
178+
with pytest.raises(ValueError):
179+
dpt.add(a, Ellipsis)
180+
dpt.add(a, a, order="invalid")
181+
182+
29183
@pytest.mark.parametrize("dtype", _all_dtypes)
30184
def test_abs_out_type(dtype):
31185
q = get_queue_or_skip()
@@ -111,15 +265,7 @@ def test_abs_complex(dtype):
111265
def _compare_dtypes(dt, ref_dt, sycl_queue=None):
112266
assert isinstance(sycl_queue, dpctl.SyclQueue)
113267
dev = sycl_queue.sycl_device
114-
expected_dt = ref_dt
115-
if not dev.has_aspect_fp64:
116-
if expected_dt == dpt.float64:
117-
expected_dt = dpt.float32
118-
elif expected_dt == dpt.complex128:
119-
expected_dt = dpt.complex64
120-
if not dev.has_aspect_fp16:
121-
if expected_dt == dpt.float16:
122-
expected_dt = dpt.float32
268+
expected_dt = _map_to_device_dtype(ref_dt, dev)
123269
return dt == expected_dt
124270

125271

@@ -224,22 +370,60 @@ def test_add_broadcasting():
224370
assert (dpt.asnumpy(r2) == np.arange(1, 6, dtype="i4")[np.newaxis, :]).all()
225371

226372

227-
def _map_to_device_dtype(dt, dev):
228-
if np.issubdtype(dt, np.integer):
229-
return dt
230-
if np.issubdtype(dt, np.floating):
231-
dtc = np.dtype(dt).char
232-
if dtc == "d":
233-
return dt if dev.has_aspect_fp64 else dpt.float32
234-
elif dtc == "e":
235-
return dt if dev.has_aspect_fp16 else dpt.float32
236-
return dt
237-
if np.issubdtype(dt, np.complexfloating):
238-
dtc = np.dtype(dt).char
239-
if dtc == "D":
240-
return dt if dev.has_aspect_fp64 else dpt.complex64
241-
return dt
242-
return dt
373+
@pytest.mark.parametrize("arr_dt", _all_dtypes)
374+
def test_add_python_scalar(arr_dt):
375+
q = get_queue_or_skip()
376+
skip_if_dtype_not_supported(arr_dt, q)
377+
378+
X = dpt.zeros((10, 10), dtype=arr_dt, sycl_queue=q)
379+
py_zeros = (
380+
bool(0),
381+
int(0),
382+
float(0),
383+
complex(0),
384+
np.float32(0),
385+
ctypes.c_int(0),
386+
)
387+
for sc in py_zeros:
388+
R = dpt.add(X, sc)
389+
assert isinstance(R, dpt.usm_ndarray)
390+
R = dpt.add(sc, X)
391+
assert isinstance(R, dpt.usm_ndarray)
392+
393+
394+
class MockArray:
395+
def __init__(self, arr):
396+
self.data_ = arr
397+
398+
@property
399+
def __sycl_usm_array_interface__(self):
400+
return self.data_.__sycl_usm_array_interface__
401+
402+
403+
def test_add_mock_array():
404+
get_queue_or_skip()
405+
a = dpt.arange(10)
406+
b = dpt.ones(10)
407+
c = MockArray(b)
408+
r = dpt.add(a, c)
409+
assert isinstance(r, dpt.usm_ndarray)
410+
411+
412+
def test_add_canary_mock_array():
413+
get_queue_or_skip()
414+
a = dpt.arange(10)
415+
416+
class Canary:
417+
def __init__(self):
418+
pass
419+
420+
@property
421+
def __sycl_usm_array_interface__(self):
422+
return None
423+
424+
c = Canary()
425+
with pytest.raises(ValueError):
426+
dpt.add(a, c)
243427

244428

245429
@pytest.mark.parametrize("dtype", _all_dtypes)

0 commit comments

Comments
 (0)