Skip to content

Commit f8cfaa7

Browse files
authored
Merge pull request #1727 from IntelPython/resolve-gh-1723
`extract` and `_extract_impl` now coerce USM type
2 parents 734fd8b + 4f6524d commit f8cfaa7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,9 @@ def _extract_impl(ary, ary_mask, axis=0):
707707
raise TypeError(
708708
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
709709
)
710+
dst_usm_type = dpctl.utils.get_coerced_usm_type(
711+
(ary.usm_type, ary_mask.usm_type)
712+
)
710713
exec_q = dpctl.utils.get_execution_queue(
711714
(ary.sycl_queue, ary_mask.sycl_queue)
712715
)
@@ -733,7 +736,7 @@ def _extract_impl(ary, ary_mask, axis=0):
733736
)
734737
dst_shape = ary.shape[:pp] + (mask_count,) + ary.shape[pp + mask_nd :]
735738
dst = dpt.empty(
736-
dst_shape, dtype=ary.dtype, usm_type=ary.usm_type, device=ary.device
739+
dst_shape, dtype=ary.dtype, usm_type=dst_usm_type, device=ary.device
737740
)
738741
if dst.size == 0:
739742
return dst

0 commit comments

Comments
 (0)