Skip to content

Commit 323988f

Browse files
Reword logic in copy_from_device
In deciding where a copy from device could be done using a direct call to memcpy, we queried usm type of the source allocation with respect to the context to which the destination allocation was bound. The test was assuming that any answer other than 'unknown' indicated that the USM allocation was known and a memcpy could be called. This did not work well when destination was a host device. For example, this crashed: ``` SYCL_ENABLE_HOST_DEVICE=1 python -m pytest dpctl/tests/test_usm_ndarray_ctor.py::test_to_device ``` It passes now.
1 parent a8f4254 commit 323988f

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

dpctl/memory/_memory.pyx

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from dpctl._backend cimport ( # noqa: E211
3333
DPCTLaligned_alloc_device,
3434
DPCTLaligned_alloc_host,
3535
DPCTLaligned_alloc_shared,
36+
DPCTLContext_AreEq,
3637
DPCTLContext_Delete,
3738
DPCTLDevice_Copy,
3839
DPCTLEvent_Delete,
@@ -422,8 +423,10 @@ cdef class _Memory:
422423
the memory of the instance
423424
"""
424425
cdef _USMBufferData src_buf
425-
cdef const char* kind
426426
cdef DPCTLSyclEventRef ERef = NULL
427+
cdef bint same_contexts = False
428+
cdef SyclQueue this_queue = None
429+
cdef SyclQueue src_queue = None
427430

428431
if not hasattr(sycl_usm_ary, '__sycl_usm_array_interface__'):
429432
raise ValueError(
@@ -439,23 +442,28 @@ cdef class _Memory:
439442
"Source object is too large to "
440443
"be accommondated in {} bytes buffer".format(self.nbytes)
441444
)
442-
kind = DPCTLUSM_GetPointerType(
443-
src_buf.p, self.queue.get_sycl_context().get_context_ref())
444-
if (kind == b'unknown'):
445-
copy_via_host(
446-
<void *>self.memory_ptr, self.queue, # dest
447-
<void *>src_buf.p, src_buf.queue, # src
448-
<size_t>src_buf.nbytes
445+
446+
src_queue = src_buf.queue
447+
this_queue = self.queue
448+
same_contexts = DPCTLContext_AreEq(
449+
src_queue.get_sycl_context().get_context_ref(),
450+
this_queue.get_sycl_context().get_context_ref()
449451
)
450-
else:
452+
if (same_contexts):
451453
ERef = DPCTLQueue_Memcpy(
452-
self.queue.get_queue_ref(),
454+
this_queue.get_queue_ref(),
453455
<void *>self.memory_ptr,
454456
<void *>src_buf.p,
455457
<size_t>src_buf.nbytes
456458
)
457459
DPCTLEvent_Wait(ERef)
458460
DPCTLEvent_Delete(ERef)
461+
else:
462+
copy_via_host(
463+
<void *>self.memory_ptr, this_queue, # dest
464+
<void *>src_buf.p, src_queue, # src
465+
<size_t>src_buf.nbytes
466+
)
459467
else:
460468
raise TypeError
461469

0 commit comments

Comments
 (0)