@@ -33,6 +33,7 @@ from dpctl._backend cimport ( # noqa: E211
33
33
DPCTLaligned_alloc_device,
34
34
DPCTLaligned_alloc_host,
35
35
DPCTLaligned_alloc_shared,
36
+ DPCTLContext_AreEq,
36
37
DPCTLContext_Delete,
37
38
DPCTLDevice_Copy,
38
39
DPCTLEvent_Delete,
@@ -422,8 +423,10 @@ cdef class _Memory:
422
423
the memory of the instance
423
424
"""
424
425
cdef _USMBufferData src_buf
425
- cdef const char * kind
426
426
cdef DPCTLSyclEventRef ERef = NULL
427
+ cdef bint same_contexts = False
428
+ cdef SyclQueue this_queue = None
429
+ cdef SyclQueue src_queue = None
427
430
428
431
if not hasattr (sycl_usm_ary, ' __sycl_usm_array_interface__' ):
429
432
raise ValueError (
@@ -439,23 +442,28 @@ cdef class _Memory:
439
442
" Source object is too large to "
440
443
" be accommondated in {} bytes buffer" .format(self .nbytes)
441
444
)
442
- kind = DPCTLUSM_GetPointerType(
443
- src_buf.p, self .queue.get_sycl_context().get_context_ref())
444
- if (kind == b' unknown' ):
445
- copy_via_host(
446
- < void * > self .memory_ptr, self .queue, # dest
447
- < void * > src_buf.p, src_buf.queue, # src
448
- < size_t> src_buf.nbytes
445
+
446
+ src_queue = src_buf.queue
447
+ this_queue = self .queue
448
+ same_contexts = DPCTLContext_AreEq(
449
+ src_queue.get_sycl_context().get_context_ref(),
450
+ this_queue.get_sycl_context().get_context_ref()
449
451
)
450
- else :
452
+ if (same_contexts) :
451
453
ERef = DPCTLQueue_Memcpy(
452
- self .queue .get_queue_ref(),
454
+ this_queue .get_queue_ref(),
453
455
< void * > self .memory_ptr,
454
456
< void * > src_buf.p,
455
457
< size_t> src_buf.nbytes
456
458
)
457
459
DPCTLEvent_Wait(ERef)
458
460
DPCTLEvent_Delete(ERef)
461
+ else :
462
+ copy_via_host(
463
+ < void * > self .memory_ptr, this_queue, # dest
464
+ < void * > src_buf.p, src_queue, # src
465
+ < size_t> src_buf.nbytes
466
+ )
459
467
else :
460
468
raise TypeError
461
469
0 commit comments