Skip to content

Commit b3db485

Browse files
Improve dpctl.memory._memory coverage
1 parent c9a4e9f commit b3db485

File tree

1 file changed

+94
-24
lines changed

1 file changed

+94
-24
lines changed

dpctl/tests/test_sycl_usm.py

Lines changed: 94 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,20 @@ def __sycl_usm_array_interface(self):
4848
not has_sycl_platforms(),
4949
reason="No SYCL devices except the default host device.",
5050
)
51-
def test_memory_create():
51+
def test_memory_create(memory_ctor):
52+
import sys
53+
5254
nbytes = 1024
53-
queue = dpctl.get_current_queue()
54-
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
55+
queue = dpctl.SyclQueue()
56+
mobj = memory_ctor(nbytes, alignment=64, queue=queue)
5557
assert mobj.nbytes == nbytes
5658
assert hasattr(mobj, "__sycl_usm_array_interface__")
59+
assert len(mobj) == nbytes
60+
assert mobj.size == nbytes
61+
assert mobj._context == queue.sycl_context
62+
assert type(repr(mobj)) is str
63+
assert type(bytes(mobj)) is bytes
64+
assert sys.getsizeof(mobj) > nbytes
5765

5866

5967
@pytest.mark.skipif(
@@ -69,7 +77,7 @@ def test_memory_create_with_np():
6977

7078
def _create_memory():
7179
nbytes = 1024
72-
queue = dpctl.get_current_queue()
80+
queue = dpctl.SyclQueue()
7381
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
7482
return mobj
7583

@@ -90,38 +98,36 @@ def test_memory_without_context():
9098

9199
# Without context
92100
assert mobj.get_usm_type() == "shared"
101+
assert mobj.get_usm_type(syclobj=dpctl.SyclContext()) == "shared"
93102

94103

95104
@pytest.mark.skipif(not has_cpu(), reason="No SYCL CPU device available.")
96105
def test_memory_cpu_context():
97106
mobj = _create_memory()
98107

99-
# CPU context
100-
with dpctl.device_context("opencl:cpu:0"):
101-
# type respective to the context in which
102-
# memory was created
103-
usm_type = mobj.get_usm_type()
104-
assert usm_type == "shared"
108+
# type respective to the context in which
109+
# memory was created
110+
usm_type = mobj.get_usm_type()
111+
assert usm_type == "shared"
105112

106-
current_queue = dpctl.get_current_queue()
107-
# type as view from current queue
108-
usm_type = mobj.get_usm_type(current_queue)
109-
# type can be unknown if current queue is
110-
# not in the same SYCL context
111-
assert usm_type in ["unknown", "shared"]
113+
cpu_queue = dpctl.SyclQueue("cpu")
114+
# type as view from CPU queue
115+
usm_type = mobj.get_usm_type(cpu_queue)
116+
# type can be unknown if current queue is
117+
# not in the same SYCL context
118+
assert usm_type in ["unknown", "shared"]
112119

113120

114121
@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
115122
def test_memory_gpu_context():
116123
mobj = _create_memory()
117124

118125
# GPU context
119-
with dpctl.device_context("opencl:gpu:0"):
120-
usm_type = mobj.get_usm_type()
121-
assert usm_type == "shared"
122-
current_queue = dpctl.get_current_queue()
123-
usm_type = mobj.get_usm_type(current_queue)
124-
assert usm_type in ["unknown", "shared"]
126+
usm_type = mobj.get_usm_type()
127+
assert usm_type == "shared"
128+
gpu_queue = dpctl.SyclQueue("opencl:gpu")
129+
usm_type = mobj.get_usm_type(gpu_queue)
130+
assert usm_type in ["unknown", "shared"]
125131

126132

127133
@pytest.mark.skipif(
@@ -166,10 +172,10 @@ def test_zero_copy():
166172
not has_sycl_platforms(),
167173
reason="No SYCL devices except the default host device.",
168174
)
169-
def test_pickling():
175+
def test_pickling(memory_ctor):
170176
import pickle
171177

172-
mobj = _create_memory()
178+
mobj = memory_ctor(1024, alignment=64)
173179
host_src_obj = _create_host_buf(mobj.nbytes)
174180
mobj.copy_from_host(host_src_obj)
175181

@@ -185,6 +191,27 @@ def test_pickling():
185191
), "Pickling/unpickling should be changing pointer"
186192

187193

194+
@pytest.mark.skipif(
195+
not has_sycl_platforms(),
196+
reason="No SYCL devices except the default host device.",
197+
)
198+
def test_pickling_reconstructor_invalid_type(memory_ctor):
199+
import pickle
200+
201+
mobj = memory_ctor(1024, alignment=64)
202+
good_pickle_bytes = pickle.dumps(mobj)
203+
usm_types = expected_usm_type(memory_ctor).encode("utf-8")
204+
i = good_pickle_bytes.index(usm_types)
205+
assert i > 4
206+
bad_pickle_bytes = (
207+
good_pickle_bytes[: i - 4]
208+
+ b"\x07\x00\x00\x00unknown"
209+
+ good_pickle_bytes[i + len(usm_types) :]
210+
)
211+
with pytest.raises(ValueError):
212+
pickle.loads(bad_pickle_bytes)
213+
214+
188215
@pytest.fixture(params=[MemoryUSMShared, MemoryUSMDevice, MemoryUSMHost])
189216
def memory_ctor(request):
190217
return request.param
@@ -389,3 +416,46 @@ def test_with_constructor(memory_ctor):
389416
syclobj=buf.sycl_device.filter_string,
390417
)
391418
check_view(v)
419+
420+
421+
@pytest.mark.skipif(
422+
not has_sycl_platforms(),
423+
reason="No SYCL devices except the default host device.",
424+
)
425+
def test_cpython_api(memory_ctor):
426+
import ctypes
427+
import sys
428+
429+
mobj = memory_ctor(1024)
430+
mod = sys.modules[mobj.__class__.__module__]
431+
# get capsules storing function pointers
432+
mem_ptr_fn_cap = mod.__pyx_capi__["get_usm_pointer"]
433+
mem_ctx_fn_cap = mod.__pyx_capi__["get_context"]
434+
mem_nby_fn_cap = mod.__pyx_capi__["get_nbytes"]
435+
# construct Python callable to invoke "get_usm_pointer"
436+
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
437+
cap_ptr_fn.restype = ctypes.c_void_p
438+
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
439+
mem_ptr_fn_ptr = cap_ptr_fn(
440+
mem_ptr_fn_cap, b"DPCTLSyclUSMRef (struct Py_MemoryObject *)"
441+
)
442+
mem_ctx_fn_ptr = cap_ptr_fn(
443+
mem_ctx_fn_cap, b"DPCTLSyclContextRef (struct Py_MemoryObject *)"
444+
)
445+
mem_nby_fn_ptr = cap_ptr_fn(
446+
mem_nby_fn_cap, b"size_t (struct Py_MemoryObject *)"
447+
)
448+
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
449+
get_ptr_fn = callable_maker(mem_ptr_fn_ptr)
450+
get_ctx_fn = callable_maker(mem_ctx_fn_ptr)
451+
get_nby_fn = callable_maker(mem_nby_fn_ptr)
452+
453+
capi_ptr = get_ptr_fn(mobj)
454+
direct_ptr = mobj._pointer
455+
assert capi_ptr == direct_ptr
456+
capi_ctx_ref = get_ctx_fn(mobj)
457+
direct_ctx_ref = mobj._context.addressof_ref()
458+
assert capi_ctx_ref == direct_ctx_ref
459+
capi_nbytes = get_nby_fn(mobj)
460+
direct_nbytes = mobj.nbytes
461+
assert capi_nbytes == direct_nbytes

0 commit comments

Comments
 (0)