Skip to content

Commit 2f8156c

Browse files
Merge pull request #1101 from IntelPython/fix-boolean-mask-asignment-scalar
Fix boolean mask asignment scalar
2 parents d259247 + c162c2a commit 2f8156c

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -514,13 +514,17 @@ def _place_impl(ary, ary_mask, vals, axis=0):
514514
raise TypeError(
515515
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
516516
)
517-
if not isinstance(vals, dpt.usm_ndarray):
518-
raise TypeError(
519-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
520-
)
521517
exec_q = dpctl.utils.get_execution_queue(
522-
(ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue)
518+
(
519+
ary.sycl_queue,
520+
ary_mask.sycl_queue,
521+
)
523522
)
523+
if exec_q is not None:
524+
if not isinstance(vals, dpt.usm_ndarray):
525+
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
526+
else:
527+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
524528
if exec_q is None:
525529
raise dpctl.utils.ExecutionPlacementError(
526530
"arrays have different associated queues. "

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,3 +1206,15 @@ def test_nonzero():
12061206
x = dpt.concat((dpt.zeros(3), dpt.ones(4), dpt.zeros(3)))
12071207
(i,) = dpt.nonzero(x)
12081208
assert (dpt.asnumpy(i) == np.array([3, 4, 5, 6])).all()
1209+
1210+
1211+
def test_assign_scalar():
1212+
get_queue_or_skip()
1213+
x = dpt.arange(-5, 5, dtype="i8")
1214+
cond = dpt.asarray(
1215+
[True, True, True, True, True, False, False, False, False, False]
1216+
)
1217+
x[cond] = 0 # no error expected
1218+
x[dpt.nonzero(cond)] = -1
1219+
expected = np.array([-1, -1, -1, -1, -1, 0, 1, 2, 3, 4], dtype=x.dtype)
1220+
assert (dpt.asnumpy(x) == expected).all()

0 commit comments

Comments
 (0)