Skip to content

Commit e0f08ad

Browse files
committed
Adjusted where test data types
- Set dtypes used in tests (where possible) to i4 to prevent allocation failures on GPUs without double support - Made tests for strided where kernel data more robust
1 parent 2f3af4a commit e0f08ad

File tree

2 files changed

+41
-22
lines changed

2 files changed

+41
-22
lines changed

dpctl/tensor/libtensor/source/where.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,8 @@ py_where(dpctl::tensor::usm_ndarray condition,
166166

167167
auto where_ev = contig_fn(exec_q, nelems, cond_data, x1_data, x2_data,
168168
dst_data, depends);
169-
sycl::event ht_ev = dpctl::utils::keep_args_alive(
170-
exec_q, {x1, x2, dst, condition}, {where_ev});
169+
sycl::event ht_ev =
170+
keep_args_alive(exec_q, {x1, x2, dst, condition}, {where_ev});
171171

172172
return std::make_pair(ht_ev, where_ev);
173173
}

dpctl/tests/test_usm_ndarray_search_functions.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def test_where_basic():
7070

7171
out = dpt.where(
7272
cond,
73-
dpt.ones(cond.shape[0])[:, dpt.newaxis],
74-
dpt.zeros(cond.shape[0])[:, dpt.newaxis],
73+
dpt.ones(cond.shape[0], dtype="i4")[:, dpt.newaxis],
74+
dpt.zeros(cond.shape[0], dtype="i4")[:, dpt.newaxis],
7575
)
7676
assert (dpt.asnumpy(out) == dpt.asnumpy(out_expected)).all()
7777

@@ -162,13 +162,13 @@ def test_where_empty():
162162
# handling empty arrays
163163
get_queue_or_skip()
164164

165-
empty = dpt.empty(0)
165+
empty = dpt.empty(0, dtype="i2")
166166
m = dpt.asarray(True)
167-
x1 = dpt.asarray(1)
168-
x2 = dpt.asarray(2)
167+
x1 = dpt.asarray(1, dtype="i2")
168+
x2 = dpt.asarray(2, dtype="i2")
169169
res = dpt.where(empty, x1, x2)
170170

171-
empty_np = np.empty(0)
171+
empty_np = np.empty(0, dtype="i2")
172172
m_np = dpt.asnumpy(m)
173173
x1_np = dpt.asnumpy(x1)
174174
x2_np = dpt.asnumpy(x2)
@@ -201,8 +201,8 @@ def test_where_contiguous(order):
201201
order=order,
202202
)
203203

204-
x1 = dpt.full(cond.shape, 2, order=order)
205-
x2 = dpt.full(cond.shape, 3, order=order)
204+
x1 = dpt.full(cond.shape, 2, dtype="i4", order=order)
205+
x2 = dpt.full(cond.shape, 3, dtype="i4", order=order)
206206
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
207207
res = dpt.where(cond, x1, x2)
208208

@@ -214,11 +214,11 @@ def test_where_contiguous1D():
214214

215215
cond = dpt.asarray([True, False, True, False, False, True])
216216

217-
x1 = dpt.full(cond.shape, 2)
218-
x2 = dpt.full(cond.shape, 3)
217+
x1 = dpt.full(cond.shape, 2, dtype="i4")
218+
x2 = dpt.full(cond.shape, 3, dtype="i4")
219219
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
220220
res = dpt.where(cond, x1, x2)
221-
assert _dtype_all_close(dpt.asnumpy(res), expected)
221+
assert_array_equal(dpt.asnumpy(res), expected)
222222

223223
# test with complex dtype (branch in kernel)
224224
x1 = dpt.astype(x1, dpt.complex64)
@@ -239,20 +239,39 @@ def test_where_strided():
239239
(s0, s1),
240240
)[:, ::3]
241241

242-
x1 = dpt.ones((cond.shape[0], cond.shape[1] * 2))[:, ::2]
243-
x2 = dpt.zeros((cond.shape[0], cond.shape[1] * 3))[:, ::3]
242+
x1 = dpt.reshape(
243+
dpt.arange(cond.shape[0] * cond.shape[1] * 2, dtype="i4"),
244+
(cond.shape[0], cond.shape[1] * 2),
245+
)[:, ::2]
246+
x2 = dpt.reshape(
247+
dpt.arange(cond.shape[0] * cond.shape[1] * 3, dtype="i4"),
248+
(cond.shape[0], cond.shape[1] * 3),
249+
)[:, ::3]
244250
expected = np.where(dpt.asnumpy(cond), dpt.asnumpy(x1), dpt.asnumpy(x2))
245251
res = dpt.where(cond, x1, x2)
246252

247-
assert _dtype_all_close(dpt.asnumpy(res), expected)
253+
assert_array_equal(dpt.asnumpy(res), expected)
254+
255+
# negative strides
256+
res = dpt.where(cond, dpt.flip(x1), x2)
257+
expected = np.where(
258+
dpt.asnumpy(cond), np.flip(dpt.asnumpy(x1)), dpt.asnumpy(x2)
259+
)
260+
assert_array_equal(dpt.asnumpy(res), expected)
261+
262+
res = dpt.where(dpt.flip(cond), x1, x2)
263+
expected = np.where(
264+
np.flip(dpt.asnumpy(cond)), dpt.asnumpy(x1), dpt.asnumpy(x2)
265+
)
266+
assert_array_equal(dpt.asnumpy(res), expected)
248267

249268

250269
def test_where_arg_validation():
251270
get_queue_or_skip()
252271

253272
check = dict()
254-
x1 = dpt.empty((1,))
255-
x2 = dpt.empty((1,))
273+
x1 = dpt.empty((1,), dtype="i4")
274+
x2 = dpt.empty((1,), dtype="i4")
256275

257276
with pytest.raises(TypeError):
258277
dpt.where(check, x1, x2)
@@ -267,12 +286,12 @@ def test_where_compute_follows_data():
267286
q2 = get_queue_or_skip()
268287
q3 = get_queue_or_skip()
269288

270-
x1 = dpt.empty((1,), sycl_queue=q1)
271-
x2 = dpt.empty((1,), sycl_queue=q2)
289+
x1 = dpt.empty((1,), dtype="i4", sycl_queue=q1)
290+
x2 = dpt.empty((1,), dtype="i4", sycl_queue=q2)
272291

273292
with pytest.raises(ExecutionPlacementError):
274-
dpt.where(dpt.empty((1,), sycl_queue=q1), x1, x2)
293+
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q1), x1, x2)
275294
with pytest.raises(ExecutionPlacementError):
276-
dpt.where(dpt.empty((1,), sycl_queue=q3), x1, x2)
295+
dpt.where(dpt.empty((1,), dtype="i4", sycl_queue=q3), x1, x2)
277296
with pytest.raises(ExecutionPlacementError):
278297
dpt.where(x1, x1, x2)

0 commit comments

Comments
 (0)