Skip to content

Commit 0ef2833

Browse files
Added check in deallocation that sycl.queue is of the right type
Fixes an unreferenced crash uncovered by tests added to improve coverage
1 parent 6831d11 commit 0ef2833

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

dpctl/memory/_memory.pyx

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ from dpctl._backend cimport ( # noqa: E211
3434
DPCTLaligned_alloc_host,
3535
DPCTLaligned_alloc_shared,
3636
DPCTLContext_Delete,
37+
DPCTLDevice_Copy,
3738
DPCTLEvent_Delete,
3839
DPCTLEvent_Wait,
3940
DPCTLfree_with_queue,
@@ -48,6 +49,7 @@ from dpctl._backend cimport ( # noqa: E211
4849
DPCTLSyclContextRef,
4950
DPCTLSyclDeviceRef,
5051
DPCTLSyclEventRef,
52+
DPCTLSyclQueueRef,
5153
DPCTLSyclUSMRef,
5254
DPCTLUSM_GetPointerDevice,
5355
DPCTLUSM_GetPointerType,
@@ -138,7 +140,7 @@ cdef class _Memory:
138140

139141
cdef _cinit_alloc(self, Py_ssize_t alignment, Py_ssize_t nbytes,
140142
bytes ptr_type, SyclQueue queue):
141-
cdef DPCTLSyclUSMRef p
143+
cdef DPCTLSyclUSMRef p = NULL
142144

143145
self._cinit_empty()
144146

@@ -215,10 +217,12 @@ cdef class _Memory:
215217
)
216218

217219
def __dealloc__(self):
218-
if (self.refobj is None and self.memory_ptr):
219-
DPCTLfree_with_queue(
220-
self.memory_ptr, self.queue.get_queue_ref()
221-
)
220+
if (self.refobj is None):
221+
if self.memory_ptr:
222+
if (type(self.queue) is SyclQueue):
223+
DPCTLfree_with_queue(
224+
self.memory_ptr, self.queue.get_queue_ref()
225+
)
222226
self._cinit_empty()
223227

224228
cdef _getbuffer(self, Py_buffer *buffer, int flags):
@@ -267,7 +271,7 @@ cdef class _Memory:
267271
property _queue:
268272
"""
269273
:class:`dpctl.SyclQueue` with :class:`dpctl.SyclContext` the
270-
USM pointer is bound to and :class:`dpctl.SyclDevice` it was
274+
USM allocation is bound to and :class:`dpctl.SyclDevice` it was
271275
allocated on.
272276
"""
273277
def __get__(self):
@@ -477,8 +481,10 @@ cdef class _Memory:
477481
cdef DPCTLSyclDeviceRef dref = DPCTLUSM_GetPointerDevice(
478482
p, ctx.get_context_ref()
479483
)
480-
481-
return SyclDevice._create(dref)
484+
cdef DPCTLSyclDeviceRef dref_copy = DPCTLDevice_Copy(dref)
485+
if (dref_copy is NULL):
486+
raise RuntimeError("Could not create a copy of sycl device")
487+
return SyclDevice._create(dref_copy) # destroys argsument
482488

483489
@staticmethod
484490
cdef bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx):

0 commit comments

Comments
 (0)