Skip to content

Commit 66c56a4

Browse files
Add an empty value check for dpt.place
1 parent 6dc5479 commit 66c56a4

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

dpctl/tensor/_indexing_functions.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ def place(arr, mask, vals):
293293
raise dpctl.utils.ExecutionPlacementError
294294
if arr.shape != mask.shape or vals.ndim != 1:
295295
raise ValueError("Array sizes are not as required")
296+
if vals.size == 0:
297+
raise ValueError("Cannot insert from an empty array!")
296298
cumsum = dpt.empty(mask.size, dtype="i8", sycl_queue=exec_q)
297299
nz_count = ti.mask_positions(mask, cumsum, sycl_queue=exec_q)
298300
if nz_count == 0:

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,16 @@ def test_place_subset():
12011201
assert (dpt.asnumpy(x) == expected).all()
12021202

12031203

1204+
def test_place_empty_vals_error():
1205+
get_queue_or_skip()
1206+
x = dpt.zeros(10, dtype="f4")
1207+
y = dpt.empty((0,), dtype=x.dtype)
1208+
sel = dpt.ones(x.size, dtype="?")
1209+
sel[::2] = False
1210+
with pytest.raises(ValueError):
1211+
dpt.place(x, sel, y)
1212+
1213+
12041214
def test_nonzero():
12051215
get_queue_or_skip()
12061216
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))

0 commit comments

Comments
 (0)