Skip to content

Commit 7392756

Browse files
committed
Where changed for empty and F-contiguous input
- Where now outputs an F-contiguous array when all inputs are F-contiguous - Where now outputs a empty 0D array if any input is a 0D empty array - Added tests for these cases Fixed incorrect logic in where test
1 parent a3c95a5 commit 7392756

File tree

2 files changed

+103
-19
lines changed

2 files changed

+103
-19
lines changed

dpctl/tensor/_search_functions.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,21 @@ def where(condition, x1, x2):
7272
x1_dtype = x1.dtype
7373
x2_dtype = x2.dtype
7474
dst_dtype = _where_result_type(x1_dtype, x2_dtype, exec_q.sycl_device)
75-
76-
if condition.size == 0:
77-
return dpt.asarray(
78-
(), dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
75+
if dst_dtype is None:
76+
raise TypeError(
77+
"function 'where' does not support input "
78+
f"types ({x1_dtype}, {x2_dtype}), "
79+
"and the inputs could not be safely coerced "
80+
"to any supported types according to the casting rule ''safe''."
7981
)
8082

8183
res_shape = _broadcast_shapes(condition, x1, x2)
8284

85+
if condition.size == 0:
86+
return dpt.empty(
87+
res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
88+
)
89+
8390
deps = []
8491
wait_list = []
8592
if x1_dtype is not dst_dtype:
@@ -104,8 +111,25 @@ def where(condition, x1, x2):
104111
x1 = dpt.broadcast_to(x1, res_shape)
105112
x2 = dpt.broadcast_to(x2, res_shape)
106113

114+
# dst is F-contiguous when all inputs are F contiguous
115+
# otherwise, defaults to C-contiguous
116+
if all(
117+
(
118+
condition.flags.fnc,
119+
x1.flags.fnc,
120+
x2.flags.fnc,
121+
)
122+
):
123+
order = "F"
124+
else:
125+
order = "C"
126+
107127
dst = dpt.empty(
108-
res_shape, dtype=dst_dtype, usm_type=dst_usm_type, sycl_queue=exec_q
128+
res_shape,
129+
dtype=dst_dtype,
130+
order=order,
131+
usm_type=dst_usm_type,
132+
sycl_queue=exec_q,
109133
)
110134

111135
hev, _ = ti._where(

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 74 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040

4141
def test_where_basic():
42-
get_queue_or_skip
42+
get_queue_or_skip()
4343

4444
cond = dpt.asarray(
4545
[
@@ -58,27 +58,87 @@ def test_where_basic():
5858
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
5959

6060

61+
def _dtype_all_close(x1, x2):
62+
if np.issubdtype(x2.dtype, np.floating) or np.issubdtype(
63+
x2.dtype, np.complexfloating
64+
):
65+
x2_dtype = x2.dtype
66+
return np.allclose(
67+
x1, x2, atol=np.finfo(x2_dtype).eps, rtol=np.finfo(x2_dtype).eps
68+
)
69+
else:
70+
return np.allclose(x1, x2)
71+
72+
6173
@pytest.mark.parametrize("dt1", _all_dtypes)
6274
@pytest.mark.parametrize("dt2", _all_dtypes)
6375
def test_where_all_dtypes(dt1, dt2):
6476
q = get_queue_or_skip()
6577
skip_if_dtype_not_supported(dt1, q)
6678
skip_if_dtype_not_supported(dt2, q)
6779

68-
cond_np = np.arange(5) > 2
69-
x1_np = np.asarray(2, dtype=dt1)
70-
x2_np = np.asarray(3, dtype=dt2)
71-
72-
cond = dpt.asarray(cond_np, sycl_queue=q)
73-
x1 = dpt.asarray(x1_np, sycl_queue=q)
74-
x2 = dpt.asarray(x2_np, sycl_queue=q)
80+
cond = dpt.asarray([False, False, False, True, True], sycl_queue=q)
81+
x1 = dpt.asarray(2, sycl_queue=q)
82+
x2 = dpt.asarray(3, sycl_queue=q)
7583

7684
res = dpt.where(cond, x1, x2)
77-
res_np = np.where(cond_np, x1_np, x2_np)
85+
res_check = np.asarray([3, 3, 3, 2, 2], dtype=res.dtype)
7886

79-
if res.dtype != res_np.dtype:
80-
assert res.dtype.kind == res_np.dtype.kind
81-
assert_array_equal(dpt.asnumpy(res).astype(res_np.dtype), res_np)
87+
dev = q.sycl_device
8288

83-
else:
84-
assert_array_equal(dpt.asnumpy(res), res_np)
89+
if not dev.has_aspect_fp16 or not dev.has_aspect_fp64:
90+
assert res.dtype.kind == dpt.result_type(x1.dtype, x2.dtype).kind
91+
92+
assert _dtype_all_close(dpt.asnumpy(res), res_check)
93+
94+
95+
def test_where_empty():
96+
# check that numpy returns same results when
97+
# handling empty arrays
98+
get_queue_or_skip()
99+
100+
empty = dpt.empty(0)
101+
m = dpt.asarray(True)
102+
x1 = dpt.asarray(1)
103+
x2 = dpt.asarray(2)
104+
res = dpt.where(empty, x1, x2)
105+
106+
empty_np = np.empty(0)
107+
m_np = dpt.asnumpy(m)
108+
x1_np = dpt.asnumpy(x1)
109+
x2_np = dpt.asnumpy(x2)
110+
res_np = np.where(empty_np, x1_np, x2_np)
111+
112+
assert_array_equal(dpt.asnumpy(res), res_np)
113+
114+
res = dpt.where(m, empty, x2)
115+
res_np = np.where(m_np, empty_np, x2_np)
116+
117+
assert_array_equal(dpt.asnumpy(res), res_np)
118+
119+
120+
@pytest.mark.parametrize("dt", _all_dtypes)
121+
@pytest.mark.parametrize("order", ["C", "F"])
122+
def test_where_contiguous(dt, order):
123+
q = get_queue_or_skip()
124+
skip_if_dtype_not_supported(dt, q)
125+
126+
cond = dpt.asarray(
127+
[
128+
[[True, False, False], [False, True, True]],
129+
[[False, True, False], [True, False, True]],
130+
[[False, False, True], [False, False, True]],
131+
[[False, False, False], [True, False, True]],
132+
[[True, True, True], [True, False, True]],
133+
],
134+
sycl_queue=q,
135+
order=order,
136+
)
137+
138+
x1 = dpt.full(cond.shape, 2, dtype=dt, order=order, sycl_queue=q)
139+
x2 = dpt.full(cond.shape, 3, dtype=dt, order=order, sycl_queue=q)
140+
141+
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
142+
res = dpt.where(cond, x1, x2)
143+
144+
assert _dtype_all_close(dpt.asnumpy(res), expected)

0 commit comments

Comments
 (0)