Skip to content

Commit 6308866

Browse files
Merge pull request #1065 from IntelPython/reorder-usm-type-enum-values
Reordered DPCTL_USM_HOST and DPCTL_USM_SHARED
2 parents 598eff0 + b5b0be0 commit 6308866

File tree

5 files changed

+93
-14
lines changed

5 files changed

+93
-14
lines changed

dpctl/_backend.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ cdef extern from "syclinterface/dpctl_sycl_enum_types.h":
3737
ctypedef enum _usm_type 'DPCTLSyclUSMType':
3838
_USM_UNKNOWN 'DPCTL_USM_UNKNOWN'
3939
_USM_DEVICE 'DPCTL_USM_DEVICE'
40-
_USM_HOST 'DPCTL_USM_HOST'
4140
_USM_SHARED 'DPCTL_USM_SHARED'
41+
_USM_HOST 'DPCTL_USM_HOST'
4242

4343
ctypedef enum _backend_type 'DPCTLSyclBackendType':
4444
_ALL_BACKENDS 'DPCTL_ALL_BACKENDS'

dpctl/memory/_memory.pxd

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ in dpctl.memory._memory.pyx.
2222
2323
"""
2424

25-
from .._backend cimport DPCTLSyclQueueRef, DPCTLSyclUSMRef
25+
from .._backend cimport DPCTLSyclQueueRef, DPCTLSyclUSMRef, _usm_type
2626
from .._sycl_context cimport SyclContext
2727
from .._sycl_device cimport SyclDevice
2828
from .._sycl_queue cimport SyclQueue
@@ -57,6 +57,8 @@ cdef public api class _Memory [object Py_MemoryObject, type Py_MemoryType]:
5757
@staticmethod
5858
cdef bytes get_pointer_type(DPCTLSyclUSMRef p, SyclContext ctx)
5959
@staticmethod
60+
cdef _usm_type get_pointer_type_enum(DPCTLSyclUSMRef p, SyclContext ctx)
61+
@staticmethod
6062
cdef object create_from_usm_pointer_size_qref(
6163
DPCTLSyclUSMRef USMRef,
6264
Py_ssize_t nbytes,

dpctl/memory/_memory.pyx

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,39 @@ cdef class _Memory:
376376
"SyclContext or SyclQueue"
377377
)
378378

379+
def get_usm_type_enum(self, syclobj=None):
380+
"""
381+
get_usm_type(syclobj=None)
382+
383+
Returns the type of USM allocation using Sycl context carried by
384+
`syclobj` keyword argument. Value of None is understood to query
385+
against `self.sycl_context` - the context used to create the
386+
allocation.
387+
"""
388+
cdef const char* kind
389+
cdef SyclContext ctx
390+
cdef SyclQueue q
391+
if syclobj is None:
392+
ctx = self._context
393+
return _Memory.get_pointer_type_enum(
394+
self.memory_ptr, ctx
395+
)
396+
elif isinstance(syclobj, SyclContext):
397+
ctx = <SyclContext>(syclobj)
398+
return _Memory.get_pointer_type_enum(
399+
self.memory_ptr, ctx
400+
)
401+
elif isinstance(syclobj, SyclQueue):
402+
q = <SyclQueue>(syclobj)
403+
ctx = q.get_sycl_context()
404+
return _Memory.get_pointer_type_enum(
405+
self.memory_ptr, ctx
406+
)
407+
raise TypeError(
408+
"syclobj keyword can be either None, or an instance of "
409+
"SyclContext or SyclQueue"
410+
)
411+
379412
cpdef copy_to_host(self, obj=None):
380413
"""
381414
Copy content of instance's memory into memory of ``obj``, or allocate
@@ -553,13 +586,35 @@ cdef class _Memory:
553586
)
554587
if usm_ty == _usm_type._USM_DEVICE:
555588
return b'device'
556-
elif usm_ty == _usm_type._USM_HOST:
557-
return b'host'
558589
elif usm_ty == _usm_type._USM_SHARED:
559590
return b'shared'
591+
elif usm_ty == _usm_type._USM_HOST:
592+
return b'host'
560593
else:
561594
return b'unknown'
562595

596+
@staticmethod
597+
cdef _usm_type get_pointer_type_enum(DPCTLSyclUSMRef p, SyclContext ctx):
598+
"""
599+
get_pointer_type(p, ctx)
600+
601+
Gives the SYCL(TM) USM pointer type, using ``sycl::get_pointer_type``,
602+
returning an enum value.
603+
604+
Args:
605+
p: DPCTLSyclUSMRef
606+
A pointer to test the type of.
607+
ctx: :class:`dpctl.SyclContext`
608+
Python object providing :class:`dpctl.SyclContext` against
609+
which to query for the pointer type.
610+
Returns:
611+
An enum value corresponding to the type of allocation.
612+
"""
613+
cdef _usm_type usm_ty = DPCTLUSM_GetPointerType(
614+
p, ctx.get_context_ref()
615+
)
616+
return usm_ty
617+
563618
@staticmethod
564619
cdef object create_from_usm_pointer_size_qref(
565620
DPCTLSyclUSMRef USMRef, Py_ssize_t nbytes,

dpctl/tests/test_sycl_usm.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def test_pickling_reconstructor_invalid_type(memory_ctor):
201201

202202
mobj = memory_ctor(1024, alignment=64)
203203
good_pickle_bytes = pickle.dumps(mobj)
204-
usm_types = expected_usm_type(memory_ctor).encode("utf-8")
204+
usm_types = expected_usm_type_str(memory_ctor).encode("utf-8")
205205
i = good_pickle_bytes.rfind(usm_types)
206206
bad_pickle_bytes = good_pickle_bytes[:i] + b"u" + good_pickle_bytes[i + 1 :]
207207
with pytest.raises(ValueError):
@@ -213,7 +213,7 @@ def memory_ctor(request):
213213
return request.param
214214

215215

216-
def expected_usm_type(ctor):
216+
def expected_usm_type_str(ctor):
217217
mapping = {
218218
MemoryUSMShared: "shared",
219219
MemoryUSMDevice: "device",
@@ -222,6 +222,15 @@ def expected_usm_type(ctor):
222222
return mapping.get(ctor, "unknown")
223223

224224

225+
def expected_usm_type_enum(ctor):
226+
mapping = {
227+
MemoryUSMShared: 2,
228+
MemoryUSMDevice: 1,
229+
MemoryUSMHost: 3,
230+
}
231+
return mapping.get(ctor, 0)
232+
233+
225234
@pytest.mark.skipif(
226235
not has_sycl_platforms(),
227236
reason="No SYCL devices except the default host device.",
@@ -230,7 +239,8 @@ def test_create_with_size_and_alignment_and_queue(memory_ctor):
230239
q = dpctl.SyclQueue()
231240
m = memory_ctor(1024, alignment=64, queue=q)
232241
assert m.nbytes == 1024
233-
assert m.get_usm_type() == expected_usm_type(memory_ctor)
242+
assert m.get_usm_type() == expected_usm_type_str(memory_ctor)
243+
assert m.get_usm_type_enum() == expected_usm_type_enum(memory_ctor)
234244

235245

236246
@pytest.mark.skipif(
@@ -241,7 +251,8 @@ def test_create_with_size_and_queue(memory_ctor):
241251
q = dpctl.SyclQueue()
242252
m = memory_ctor(1024, queue=q)
243253
assert m.nbytes == 1024
244-
assert m.get_usm_type() == expected_usm_type(memory_ctor)
254+
assert m.get_usm_type() == expected_usm_type_str(memory_ctor)
255+
assert m.get_usm_type_enum() == expected_usm_type_enum(memory_ctor)
245256

246257

247258
@pytest.mark.skipif(
@@ -251,17 +262,28 @@ def test_create_with_size_and_queue(memory_ctor):
251262
def test_create_with_size_and_alignment(memory_ctor):
252263
m = memory_ctor(1024, alignment=64)
253264
assert m.nbytes == 1024
254-
assert m.get_usm_type() == expected_usm_type(memory_ctor)
265+
assert m.get_usm_type() == expected_usm_type_str(memory_ctor)
266+
assert m.get_usm_type_enum() == expected_usm_type_enum(memory_ctor)
255267

256268

257269
@pytest.mark.skipif(
258270
not has_sycl_platforms(),
259271
reason="No SYCL devices except the default host device.",
260272
)
261-
def test_create_with_only_size(memory_ctor):
262-
m = memory_ctor(1024)
273+
def test_usm_type_execeptions():
274+
ctor = MemoryUSMDevice
275+
m = ctor(1024)
263276
assert m.nbytes == 1024
264-
assert m.get_usm_type() == expected_usm_type(memory_ctor)
277+
q = m.sycl_queue
278+
assert m.get_usm_type(syclobj=q) == expected_usm_type_str(ctor)
279+
assert m.get_usm_type_enum(syclobj=q) == expected_usm_type_enum(ctor)
280+
ctx = q.sycl_context
281+
assert m.get_usm_type(syclobj=ctx) == expected_usm_type_str(ctor)
282+
assert m.get_usm_type_enum(syclobj=ctx) == expected_usm_type_enum(ctor)
283+
with pytest.raises(TypeError):
284+
m.get_usm_type(syclobj=Ellipsis)
285+
with pytest.raises(TypeError):
286+
m.get_usm_type_enum(syclobj=list())
265287

266288

267289
@pytest.mark.skipif(

libsyclinterface/include/dpctl_sycl_enum_types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ typedef enum
3939
{
4040
DPCTL_USM_UNKNOWN = 0,
4141
DPCTL_USM_DEVICE,
42-
DPCTL_USM_HOST,
43-
DPCTL_USM_SHARED
42+
DPCTL_USM_SHARED,
43+
DPCTL_USM_HOST
4444
} DPCTLSyclUSMType;
4545

4646
/*!

0 commit comments

Comments
 (0)