Skip to content

Commit cd6e804

Browse files
Deployed copy-and-cast kernel to copy from numpy to usm_ndarray
1 parent 966f48b commit cd6e804

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -78,18 +78,12 @@ def _copy_from_numpy_into(dst, np_ary):
7878
"Copies `np_ary` into `dst` of type :class:`dpctl.tensor.usm_ndarray"
7979
if not isinstance(np_ary, np.ndarray):
8080
raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary)))
81-
src_ary = np.broadcast_to(np.asarray(np_ary, dtype=dst.dtype), dst.shape)
82-
if src_ary.size and (dst.flags & 1) and src_ary.flags["C"]:
83-
dpm.as_usm_memory(dst).copy_from_host(src_ary.reshape((-1,)).view("u1"))
84-
return
85-
if src_ary.size and (dst.flags & 2) and src_ary.flags["F"]:
86-
dpm.as_usm_memory(dst).copy_from_host(src_ary.reshape((-1,)).view("u1"))
87-
return
88-
for i in range(dst.size):
89-
mi = np.unravel_index(i, dst.shape)
90-
host_buf = np.array(src_ary[mi], ndmin=1).view("u1")
91-
usm_mem = dpm.as_usm_memory(dst[mi])
92-
usm_mem.copy_from_host(host_buf)
81+
if not isinstance(dst, dpt.usm_ndarray):
82+
raise TypeError("Expected usm_ndarray, got {}".format(type(dst)))
83+
src_ary = np.broadcast_to(np_ary, dst.shape)
84+
ti._copy_numpy_ndarray_into_usm_ndarray(
85+
src=src_ary, dst=dst, sycl_queue=dst.sycl_queue
86+
)
9387

9488

9589
def from_numpy(np_ary, device=None, usm_type="device", sycl_queue=None):

0 commit comments

Comments
 (0)