Skip to content

Commit d6b7c66

Browse files
Merge pull request #951 from IntelPython/support-cross-device-asarray
dpt.asarray can now migrate arrays across devices
2 parents 651b516 + 721f9b2 commit d6b7c66

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

dpctl/tensor/_ctors.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,15 @@ def _asarray_from_usm_ndarray(
186186
order=order,
187187
buffer_ctor_kwargs={"queue": copy_q},
188188
)
189-
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
190-
src=usm_ndary, dst=res, sycl_queue=copy_q
191-
)
192-
hev.wait()
189+
eq = dpctl.utils.get_execution_queue([usm_ndary.sycl_queue, copy_q])
190+
if eq is not None:
191+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
192+
src=usm_ndary, dst=res, sycl_queue=eq
193+
)
194+
hev.wait()
195+
else:
196+
tmp = dpt.asnumpy(usm_ndary)
197+
res[...] = tmp
193198
return res
194199

195200

dpctl/tests/test_tensor_asarray.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pytest
19+
from helper import get_queue_or_skip
1920

2021
import dpctl
2122
import dpctl.tensor as dpt
@@ -193,10 +194,7 @@ def test_asarray_scalars():
193194

194195

195196
def test_asarray_copy_false():
196-
try:
197-
q = dpctl.SyclQueue()
198-
except dpctl.SyclQueueCreationError:
199-
pytest.skip("Could not create a queue")
197+
q = get_queue_or_skip()
200198
rng = np.random.default_rng()
201199
Xnp = rng.integers(low=-255, high=255, size=(10, 4), dtype=np.int64)
202200
X = dpt.from_numpy(Xnp, usm_type="device", sycl_queue=q)
@@ -229,10 +227,15 @@ def test_asarray_copy_false():
229227

230228

231229
def test_asarray_invalid_dtype():
232-
try:
233-
q = dpctl.SyclQueue()
234-
except dpctl.SyclQueueCreationError:
235-
pytest.skip("Could not create a queue")
230+
q = get_queue_or_skip()
236231
Xnp = np.array([1, 2, 3], dtype=object)
237232
with pytest.raises(TypeError):
238233
dpt.asarray(Xnp, sycl_queue=q)
234+
235+
236+
def test_asarray_cross_device():
237+
q = get_queue_or_skip()
238+
qprof = dpctl.SyclQueue(property="enable_profiling")
239+
x = dpt.empty(10, dtype="i8", sycl_queue=q)
240+
y = dpt.asarray(x, sycl_queue=qprof)
241+
assert y.sycl_queue == qprof

0 commit comments

Comments
 (0)