@@ -34,6 +34,7 @@ from dpctl._backend cimport ( # noqa: E211
34
34
DPCTLaligned_alloc_host,
35
35
DPCTLaligned_alloc_shared,
36
36
DPCTLContext_Delete,
37
+ DPCTLDevice_Copy,
37
38
DPCTLEvent_Delete,
38
39
DPCTLEvent_Wait,
39
40
DPCTLfree_with_queue,
@@ -48,6 +49,7 @@ from dpctl._backend cimport ( # noqa: E211
48
49
DPCTLSyclContextRef,
49
50
DPCTLSyclDeviceRef,
50
51
DPCTLSyclEventRef,
52
+ DPCTLSyclQueueRef,
51
53
DPCTLSyclUSMRef,
52
54
DPCTLUSM_GetPointerDevice,
53
55
DPCTLUSM_GetPointerType,
@@ -138,7 +140,7 @@ cdef class _Memory:
138
140
139
141
cdef _cinit_alloc(self , Py_ssize_t alignment, Py_ssize_t nbytes,
140
142
bytes ptr_type, SyclQueue queue):
141
- cdef DPCTLSyclUSMRef p
143
+ cdef DPCTLSyclUSMRef p = NULL
142
144
143
145
self ._cinit_empty()
144
146
@@ -215,10 +217,12 @@ cdef class _Memory:
215
217
)
216
218
217
219
def __dealloc__ (self ):
218
- if (self .refobj is None and self .memory_ptr):
219
- DPCTLfree_with_queue(
220
- self .memory_ptr, self .queue.get_queue_ref()
221
- )
220
+ if (self .refobj is None ):
221
+ if self .memory_ptr:
222
+ if (type (self .queue) is SyclQueue):
223
+ DPCTLfree_with_queue(
224
+ self .memory_ptr, self .queue.get_queue_ref()
225
+ )
222
226
self ._cinit_empty()
223
227
224
228
cdef _getbuffer(self , Py_buffer * buffer , int flags):
@@ -267,7 +271,7 @@ cdef class _Memory:
267
271
property _queue :
268
272
"""
269
273
:class:`dpctl.SyclQueue` with :class:`dpctl.SyclContext` the
270
- USM pointer is bound to and :class:`dpctl.SyclDevice` it was
274
+ USM allocation is bound to and :class:`dpctl.SyclDevice` it was
271
275
allocated on.
272
276
"""
273
277
def __get__ (self ):
@@ -477,8 +481,10 @@ cdef class _Memory:
477
481
cdef DPCTLSyclDeviceRef dref = DPCTLUSM_GetPointerDevice(
478
482
p, ctx.get_context_ref()
479
483
)
480
-
481
- return SyclDevice._create(dref)
484
+ cdef DPCTLSyclDeviceRef dref_copy = DPCTLDevice_Copy(dref)
485
+ if (dref_copy is NULL ):
486
+ raise RuntimeError (" Could not create a copy of sycl device" )
487
+ return SyclDevice._create(dref_copy) # destroys argsument
482
488
483
489
@staticmethod
484
490
cdef bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx):
0 commit comments