Skip to content

Commit 39d7901

Browse files
committed
Added tests for where and type utility functions
1 parent 7392756 commit 39d7901

File tree

4 files changed

+219
-17
lines changed

4 files changed

+219
-17
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def where(condition, x1, x2):
8989

9090
deps = []
9191
wait_list = []
92-
if x1_dtype is not dst_dtype:
92+
if x1_dtype != dst_dtype:
9393
_x1 = dpt.empty_like(x1, dtype=dst_dtype)
9494
ht_copy1_ev, copy1_ev = ti._copy_usm_ndarray_into_usm_ndarray(
9595
src=x1, dst=_x1, sycl_queue=exec_q
@@ -98,7 +98,7 @@ def where(condition, x1, x2):
9898
deps.append(copy1_ev)
9999
wait_list.append(ht_copy1_ev)
100100

101-
if x2_dtype is not dst_dtype:
101+
if x2_dtype != dst_dtype:
102102
_x2 = dpt.empty_like(x2, dtype=dst_dtype)
103103
ht_copy2_ev, copy2_ev = ti._copy_usm_ndarray_into_usm_ndarray(
104104
src=x2, dst=_x2, sycl_queue=exec_q

dpctl/tensor/_type_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _all_data_types(_fp16, _fp64):
8484
]
8585

8686

87-
def is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
87+
def _is_maximal_inexact_type(dt: dpt.dtype, _fp16: bool, _fp64: bool):
8888
"""
8989
Return True if data type `dt` is the
9090
maximal size inexact data type
@@ -106,7 +106,7 @@ def _can_cast(from_: dpt.dtype, to_: dpt.dtype, _fp16: bool, _fp64: bool):
106106
if (
107107
from_.kind in "biu"
108108
and to_.kind in "fc"
109-
and is_maximal_inexact_type(to_, _fp16, _fp64)
109+
and _is_maximal_inexact_type(to_, _fp16, _fp64)
110110
):
111111
return True
112112

dpctl/tests/test_type_utils.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import pytest
18+
19+
import dpctl.tensor as dpt
20+
from dpctl.tensor._type_utils import (
21+
_all_data_types,
22+
_can_cast,
23+
_is_maximal_inexact_type,
24+
)
25+
26+
27+
def test_all_data_types():
28+
fp16_fp64_types = set([dpt.float16, dpt.float64, dpt.complex128])
29+
fp64_types = set([dpt.float64, dpt.complex128])
30+
31+
all_dts = _all_data_types(True, True)
32+
assert fp16_fp64_types.issubset(all_dts)
33+
34+
all_dts = _all_data_types(True, False)
35+
assert dpt.float16 in all_dts
36+
assert not fp64_types.issubset(all_dts)
37+
38+
all_dts = _all_data_types(False, True)
39+
assert dpt.float16 not in all_dts
40+
assert fp64_types.issubset(all_dts)
41+
42+
all_dts = _all_data_types(False, False)
43+
assert not fp16_fp64_types.issubset(all_dts)
44+
45+
46+
@pytest.mark.parametrize("fp16", [True, False])
47+
@pytest.mark.parametrize("fp64", [True, False])
48+
def test_maximal_inexact_types(fp16, fp64):
49+
assert not _is_maximal_inexact_type(dpt.int32, fp16, fp64)
50+
assert fp64 == _is_maximal_inexact_type(dpt.float64, fp16, fp64)
51+
assert fp64 == _is_maximal_inexact_type(dpt.complex128, fp16, fp64)
52+
assert fp64 != _is_maximal_inexact_type(dpt.float32, fp16, fp64)
53+
assert fp64 != _is_maximal_inexact_type(dpt.complex64, fp16, fp64)
54+
55+
56+
def test_can_cast_device():
57+
assert _can_cast(dpt.int64, dpt.float64, True, True)
58+
# if f8 is available, can't cast i8 to f4
59+
assert not _can_cast(dpt.int64, dpt.float32, True, True)
60+
assert not _can_cast(dpt.int64, dpt.float32, False, True)
61+
# should be able to cast to f8 when f2 unavailable
62+
assert _can_cast(dpt.int64, dpt.float64, False, True)
63+
# casting to f4 acceptable when f8 unavailable
64+
assert _can_cast(dpt.int64, dpt.float32, True, False)
65+
assert _can_cast(dpt.int64, dpt.float32, False, False)
66+
# can't safely cast inexact type to inexact type of lesser precision
67+
assert not _can_cast(dpt.float32, dpt.float16, True, False)
68+
assert not _can_cast(dpt.float64, dpt.float32, False, True)

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 147 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from numpy.testing import assert_array_equal
2121

2222
import dpctl.tensor as dpt
23+
from dpctl.tensor._search_functions import _where_result_type
24+
from dpctl.tensor._type_utils import _all_data_types
25+
from dpctl.utils import ExecutionPlacementError
2326

2427
_all_dtypes = [
2528
"u1",
@@ -38,6 +41,12 @@
3841
]
3942

4043

44+
class mock_device:
45+
def __init__(self, fp16, fp64):
46+
self.has_aspect_fp16 = fp16
47+
self.has_aspect_fp64 = fp64
48+
49+
4150
def test_where_basic():
4251
get_queue_or_skip()
4352

@@ -54,7 +63,16 @@ def test_where_basic():
5463
out_expected = dpt.asarray(
5564
[[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0], [1, 1, 1]]
5665
)
66+
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
67+
68+
out = dpt.where(cond, dpt.ones(cond.shape), dpt.zeros(cond.shape))
69+
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
5770

71+
out = dpt.where(
72+
cond,
73+
dpt.ones(cond.shape[0])[:, dpt.newaxis],
74+
dpt.zeros(cond.shape[0])[:, dpt.newaxis],
75+
)
5876
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
5977

6078

@@ -70,6 +88,31 @@ def _dtype_all_close(x1, x2):
7088
return np.allclose(x1, x2)
7189

7290

91+
@pytest.mark.parametrize("dt1", _all_dtypes)
92+
@pytest.mark.parametrize("dt2", _all_dtypes)
93+
@pytest.mark.parametrize("fp16", [True, False])
94+
@pytest.mark.parametrize("fp64", [True, False])
95+
def test_where_result_types(dt1, dt2, fp16, fp64):
96+
dev = mock_device(fp16, fp64)
97+
98+
dt1 = dpt.dtype(dt1)
99+
dt2 = dpt.dtype(dt2)
100+
res_t = _where_result_type(dt1, dt2, dev)
101+
102+
if fp16 and fp64:
103+
assert res_t == dpt.result_type(dt1, dt2)
104+
else:
105+
if res_t:
106+
assert res_t.kind == dpt.result_type(dt1, dt2).kind
107+
else:
108+
# some illegal cases are covered above, but
109+
# this guarantees that _where_result_type
110+
# produces None only when one of the dtypes
111+
# is illegal given fp aspects of device
112+
all_dts = _all_data_types(fp16, fp64)
113+
assert dt1 not in all_dts or dt2 not in all_dts
114+
115+
73116
@pytest.mark.parametrize("dt1", _all_dtypes)
74117
@pytest.mark.parametrize("dt2", _all_dtypes)
75118
def test_where_all_dtypes(dt1, dt2):
@@ -78,17 +121,39 @@ def test_where_all_dtypes(dt1, dt2):
78121
skip_if_dtype_not_supported(dt2, q)
79122

80123
cond = dpt.asarray([False, False, False, True, True], sycl_queue=q)
81-
x1 = dpt.asarray(2, sycl_queue=q)
82-
x2 = dpt.asarray(3, sycl_queue=q)
83-
124+
x1 = dpt.asarray(2, dtype=dt1, sycl_queue=q)
125+
x2 = dpt.asarray(3, dtype=dt2, sycl_queue=q)
84126
res = dpt.where(cond, x1, x2)
127+
85128
res_check = np.asarray([3, 3, 3, 2, 2], dtype=res.dtype)
129+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
86130

87-
dev = q.sycl_device
131+
# contiguous cases
132+
x1 = dpt.full(cond.shape, 2, dtype=dt1, sycl_queue=q)
133+
x2 = dpt.full(cond.shape, 3, dtype=dt2, sycl_queue=q)
134+
res = dpt.where(cond, x1, x2)
135+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
88136

89-
if not dev.has_aspect_fp16 or not dev.has_aspect_fp64:
90-
assert res.dtype.kind == dpt.result_type(x1.dtype, x2.dtype).kind
91137

138+
@pytest.mark.parametrize("dt1", _all_dtypes)
139+
@pytest.mark.parametrize("dt2", _all_dtypes)
140+
def test_where_mask_dtypes(dt1, dt2):
141+
q = get_queue_or_skip()
142+
skip_if_dtype_not_supported(dt1, q)
143+
skip_if_dtype_not_supported(dt2, q)
144+
145+
cond = dpt.asarray([0, 1, 3, 0, 10], dtype=dt1, sycl_queue=q)
146+
x1 = dpt.asarray(2, dtype=dt2, sycl_queue=q)
147+
x2 = dpt.asarray(3, dtype=dt2, sycl_queue=q)
148+
res = dpt.where(cond, x1, x2)
149+
150+
res_check = np.asarray([3, 2, 2, 3, 2], dtype=res.dtype)
151+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
152+
153+
# contiguous cases
154+
x1 = dpt.full(cond.shape, 2, dtype=dt2, sycl_queue=q)
155+
x2 = dpt.full(cond.shape, 3, dtype=dt2, sycl_queue=q)
156+
res = dpt.where(cond, x1, x2)
92157
assert _dtype_all_close(dpt.asnumpy(res), res_check)
93158

94159

@@ -116,12 +181,14 @@ def test_where_empty():
116181

117182
assert_array_equal(dpt.asnumpy(res), res_np)
118183

184+
# check that broadcasting is performed
185+
with pytest.raises(ValueError):
186+
dpt.where(empty, x1, dpt.empty((1, 2)))
187+
119188

120-
@pytest.mark.parametrize("dt", _all_dtypes)
121189
@pytest.mark.parametrize("order", ["C", "F"])
122-
def test_where_contiguous(dt, order):
123-
q = get_queue_or_skip()
124-
skip_if_dtype_not_supported(dt, q)
190+
def test_where_contiguous(order):
191+
get_queue_or_skip()
125192

126193
cond = dpt.asarray(
127194
[
@@ -131,14 +198,81 @@ def test_where_contiguous(dt, order):
131198
[[False, False, False], [True, False, True]],
132199
[[True, True, True], [True, False, True]],
133200
],
134-
sycl_queue=q,
135201
order=order,
136202
)
137203

138-
x1 = dpt.full(cond.shape, 2, dtype=dt, order=order, sycl_queue=q)
139-
x2 = dpt.full(cond.shape, 3, dtype=dt, order=order, sycl_queue=q)
204+
x1 = dpt.full(cond.shape, 2, order=order)
205+
x2 = dpt.full(cond.shape, 3, order=order)
206+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
207+
res = dpt.where(cond, x1, x2)
208+
209+
assert _dtype_all_close(dpt.asnumpy(res), expected)
210+
211+
212+
def test_where_contiguous1D():
213+
get_queue_or_skip()
214+
215+
cond = dpt.asarray([True, False, True, False, False, True])
216+
217+
x1 = dpt.full(cond.shape, 2)
218+
x2 = dpt.full(cond.shape, 3)
219+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
220+
res = dpt.where(cond, x1, x2)
221+
assert _dtype_all_close(dpt.asnumpy(res), expected)
222+
223+
# test with complex dtype (branch in kernel)
224+
x1 = dpt.astype(x1, dpt.complex64)
225+
x2 = dpt.astype(x2, dpt.complex64)
226+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
227+
res = dpt.where(cond, x1, x2)
228+
assert _dtype_all_close(dpt.asnumpy(res), expected)
229+
230+
231+
def test_where_strided():
232+
get_queue_or_skip()
233+
234+
s0, s1 = 4, 9
235+
cond = dpt.reshape(
236+
dpt.asarray(
237+
[True, False, False, False, True, True, False, True, False] * s0
238+
),
239+
(s0, s1),
240+
)[:, ::3]
140241

242+
x1 = dpt.ones((cond.shape[0], cond.shape[1] * 2))[:, ::2]
243+
x2 = dpt.zeros((cond.shape[0], cond.shape[1] * 3))[:, ::3]
141244
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
142245
res = dpt.where(cond, x1, x2)
143246

144247
assert _dtype_all_close(dpt.asnumpy(res), expected)
248+
249+
250+
def test_where_arg_validation():
251+
get_queue_or_skip()
252+
253+
check = dict()
254+
x1 = dpt.empty((1,))
255+
x2 = dpt.empty((1,))
256+
257+
with pytest.raises(TypeError):
258+
dpt.where(check, x1, x2)
259+
with pytest.raises(TypeError):
260+
dpt.where(x1, check, x2)
261+
with pytest.raises(TypeError):
262+
dpt.where(x1, x2, check)
263+
264+
265+
def test_where_compute_follows_data():
266+
q1 = get_queue_or_skip()
267+
q2 = get_queue_or_skip()
268+
q3 = get_queue_or_skip()
269+
270+
x1 = dpt.empty((1,), sycl_queue=q1)
271+
x2 = dpt.empty((1,), sycl_queue=q2)
272+
273+
with pytest.raises(ExecutionPlacementError):
274+
dpt.where(dpt.empty((1,), sycl_queue=q1), x1, x2)
275+
with pytest.raises(ExecutionPlacementError):
276+
dpt.where(dpt.empty((1,), sycl_queue=q3), x1, x2)
277+
with pytest.raises(ExecutionPlacementError):
278+
dpt.where(x1, x1, x2)

0 commit comments

Comments
 (0)