Skip to content

Commit 4f20687

Browse files
Fixed a bug
```python import dpctl.tensor X = dpt.usm_ndarray((4,4), "i4") X[:] = 4 # used to raise an error ```
1 parent a3e209f commit 4f20687

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ def copy_from_numpy(np_ary, usm_type="device", queue=None):
9999
def copy_from_numpy_into(dst, np_ary):
100100
if not isinstance(np_ary, np.ndarray):
101101
raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary)))
102+
src_ary = np.broadcast_to(np.asarray(np_ary, dtype=dst.dtype), dst.shape)
102103
for i in range(dst.size):
103104
mi = np.unravel_index(i, dst.shape)
104-
host_buf = np.array(np_ary[mi], dtype=dst.dtype, ndmin=1).view("u1")
105+
host_buf = np.array(src_ary[mi], ndmin=1).view("u1")
105106
usm_mem = dpm.as_usm_memory(dst[mi])
106107
usm_mem.copy_from_host(host_buf)
107108

0 commit comments

Comments
 (0)