Skip to content

Commit d6733b9

Browse files
Allow x[cond] = non_usm_array
This allows `x[x<0] = 0` to work. Previously, it had to be `x[x<0] = dpt.asarray(0)`.
1 parent d259247 commit d6733b9

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -514,13 +514,17 @@ def _place_impl(ary, ary_mask, vals, axis=0):
514514
raise TypeError(
515515
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
516516
)
517-
if not isinstance(vals, dpt.usm_ndarray):
518-
raise TypeError(
519-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
520-
)
521517
exec_q = dpctl.utils.get_execution_queue(
522-
(ary.sycl_queue, ary_mask.sycl_queue, vals.sycl_queue)
518+
(
519+
ary.sycl_queue,
520+
ary_mask.sycl_queue,
521+
)
523522
)
523+
if exec_q is not None:
524+
if not isinstance(vals, dpt.usm_ndarray):
525+
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
526+
else:
527+
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
524528
if exec_q is None:
525529
raise dpctl.utils.ExecutionPlacementError(
526530
"arrays have different associated queues. "

0 commit comments

Comments
 (0)