Skip to content

Commit 52773aa

Browse files
Merge pull request #949 from IntelPython/fix-issue-948
Fixed issue in asarray_from_numpy for uint64 dtype.
2 parents 2723947 + eb9b652 commit 52773aa

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

dpctl/tensor/_ctors.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,35 @@ def _asarray_from_usm_ndarray(
186186
order=order,
187187
buffer_ctor_kwargs={"queue": copy_q},
188188
)
189-
# FIXME: call copy_to when implemented
190-
res[(slice(None, None, None),) * res.ndim] = usm_ndary
189+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
190+
src=usm_ndary, dst=res, sycl_queue=copy_q
191+
)
192+
hev.wait()
191193
return res
192194

193195

196+
def _map_to_device_dtype(dt, q):
197+
if dt.char == "?" or np.issubdtype(dt, np.integer):
198+
return dt
199+
d = q.sycl_device
200+
dtc = dt.char
201+
if np.issubdtype(dt, np.floating):
202+
if dtc == "f":
203+
return dt
204+
else:
205+
if dtc == "d" and d.has_aspect_fp64:
206+
return dt
207+
if dtc == "h" and d.has_aspect_fp16:
208+
return dt
209+
return dpt.dtype("f4")
210+
elif np.issubdtype(dt, np.complexfloating):
211+
if dtc == "F":
212+
return dt
213+
if dtc == "D" and d.has_aspect_fp64:
214+
return dt
215+
return dpt.dtype("c8")
216+
217+
194218
def _asarray_from_numpy_ndarray(
195219
ary, dtype=None, usm_type=None, sycl_queue=None, order="K"
196220
):
@@ -205,10 +229,8 @@ def _asarray_from_numpy_ndarray(
205229
"Please convert the input to an array with numeric data type."
206230
)
207231
if dtype is None:
208-
ary_dtype = ary.dtype
209-
dtype = _get_dtype(dtype, copy_q, ref_type=ary_dtype)
210-
if dtype.itemsize > ary_dtype.itemsize:
211-
dtype = ary_dtype
232+
# deduce device-representable output data type
233+
dtype = _map_to_device_dtype(ary.dtype, copy_q)
212234
f_contig = ary.flags["F"]
213235
c_contig = ary.flags["C"]
214236
fc_contig = f_contig or c_contig
@@ -244,8 +266,7 @@ def _asarray_from_numpy_ndarray(
244266
order=order,
245267
buffer_ctor_kwargs={"queue": copy_q},
246268
)
247-
# FIXME: call copy_to when implemented
248-
res[(slice(None, None, None),) * res.ndim] = ary
269+
res[...] = ary
249270
return res
250271

251272

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,3 +1458,9 @@ def test_flags():
14581458
f.writable
14591459
# check comparison with generic types
14601460
f == Ellipsis
1461+
1462+
1463+
def test_asarray_uint64():
1464+
Xnp = np.ndarray(1, dtype=np.uint64)
1465+
X = dpt.asarray(Xnp)
1466+
assert X.dtype == Xnp.dtype

0 commit comments

Comments
 (0)