Skip to content

Commit a91eaf3

Browse files
Merge pull request #1117 from IntelPython/fix-for-gh-1089
Handle numpy arrays with usm memory underneath in dpctl.tensor.asarray
2 parents 51863d4 + b1355d4 commit a91eaf3

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,22 @@ def _copy_from_numpy_into(dst, np_ary):
9292
raise TypeError(f"Expected numpy.ndarray, got {type(np_ary)}")
9393
if not isinstance(dst, dpt.usm_ndarray):
9494
raise TypeError(f"Expected usm_ndarray, got {type(dst)}")
95-
src_ary = np.broadcast_to(np_ary, dst.shape)
95+
if np_ary.flags["OWNDATA"]:
96+
Xnp = np_ary
97+
else:
98+
# Determine base of input array
99+
base = np_ary.base
100+
while isinstance(base, np.ndarray):
101+
base = base.base
102+
if isinstance(base, dpm._memory._Memory):
103+
# we must perform a copy, since subsequent
104+
# _copy_numpy_ndarray_into_usm_ndarray is implemented using
105+
# sycl::buffer, and using USM-pointers with sycl::buffer
106+
# results is undefined behavior
107+
Xnp = np_ary.copy()
108+
else:
109+
Xnp = np_ary
110+
src_ary = np.broadcast_to(Xnp, dst.shape)
96111
copy_q = dst.sycl_queue
97112
if copy_q.sycl_device.has_aspect_fp64 is False:
98113
src_ary_dt_c = src_ary.dtype.char

0 commit comments

Comments
 (0)