Skip to content

Commit 42c25df

Browse files
Merge pull request #618 from IntelPython/crash-in-usm_ndarray-to_device
Rework the logic in memory's copy_from_device
1 parent 59d4d65 commit 42c25df

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

dpctl/memory/_memory.pyx

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ from dpctl._backend cimport ( # noqa: E211
3333
DPCTLaligned_alloc_device,
3434
DPCTLaligned_alloc_host,
3535
DPCTLaligned_alloc_shared,
36+
DPCTLContext_AreEq,
3637
DPCTLContext_Delete,
3738
DPCTLDevice_Copy,
3839
DPCTLEvent_Delete,
@@ -422,8 +423,10 @@ cdef class _Memory:
422423
the memory of the instance
423424
"""
424425
cdef _USMBufferData src_buf
425-
cdef const char* kind
426426
cdef DPCTLSyclEventRef ERef = NULL
427+
cdef bint same_contexts = False
428+
cdef SyclQueue this_queue = None
429+
cdef SyclQueue src_queue = None
427430

428431
if not hasattr(sycl_usm_ary, '__sycl_usm_array_interface__'):
429432
raise ValueError(
@@ -439,23 +442,28 @@ cdef class _Memory:
439442
"Source object is too large to "
440443
"be accommondated in {} bytes buffer".format(self.nbytes)
441444
)
442-
kind = DPCTLUSM_GetPointerType(
443-
src_buf.p, self.queue.get_sycl_context().get_context_ref())
444-
if (kind == b'unknown'):
445-
copy_via_host(
446-
<void *>self.memory_ptr, self.queue, # dest
447-
<void *>src_buf.p, src_buf.queue, # src
448-
<size_t>src_buf.nbytes
445+
446+
src_queue = src_buf.queue
447+
this_queue = self.queue
448+
same_contexts = DPCTLContext_AreEq(
449+
src_queue.get_sycl_context().get_context_ref(),
450+
this_queue.get_sycl_context().get_context_ref()
449451
)
450-
else:
452+
if (same_contexts):
451453
ERef = DPCTLQueue_Memcpy(
452-
self.queue.get_queue_ref(),
454+
this_queue.get_queue_ref(),
453455
<void *>self.memory_ptr,
454456
<void *>src_buf.p,
455457
<size_t>src_buf.nbytes
456458
)
457459
DPCTLEvent_Wait(ERef)
458460
DPCTLEvent_Delete(ERef)
461+
else:
462+
copy_via_host(
463+
<void *>self.memory_ptr, this_queue, # dest
464+
<void *>src_buf.p, src_queue, # src
465+
<size_t>src_buf.nbytes
466+
)
459467
else:
460468
raise TypeError
461469

0 commit comments

Comments
 (0)