Skip to content

Commit c700d83

Browse files
Improve dpctl.memory._memory coverage
1 parent c9a4e9f commit c700d83

File tree

1 file changed

+89
-24
lines changed

1 file changed

+89
-24
lines changed

dpctl/tests/test_sycl_usm.py

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,20 @@ def __sycl_usm_array_interface(self):
4848
not has_sycl_platforms(),
4949
reason="No SYCL devices except the default host device.",
5050
)
51-
def test_memory_create():
51+
def test_memory_create(memory_ctor):
52+
import sys
53+
5254
nbytes = 1024
53-
queue = dpctl.get_current_queue()
54-
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
55+
queue = dpctl.SyclQueue()
56+
mobj = memory_ctor(nbytes, alignment=64, queue=queue)
5557
assert mobj.nbytes == nbytes
5658
assert hasattr(mobj, "__sycl_usm_array_interface__")
59+
assert len(mobj) == nbytes
60+
assert mobj.size == nbytes
61+
assert mobj._context == queue.sycl_context
62+
assert type(repr(mobj)) is str
63+
assert type(bytes(mobj)) is bytes
64+
assert sys.getsizeof(mobj) > nbytes
5765

5866

5967
@pytest.mark.skipif(
@@ -69,7 +77,7 @@ def test_memory_create_with_np():
6977

7078
def _create_memory():
7179
nbytes = 1024
72-
queue = dpctl.get_current_queue()
80+
queue = dpctl.SyclQueue()
7381
mobj = MemoryUSMShared(nbytes, alignment=64, queue=queue)
7482
return mobj
7583

@@ -90,38 +98,36 @@ def test_memory_without_context():
9098

9199
# Without context
92100
assert mobj.get_usm_type() == "shared"
101+
assert mobj.get_usm_type(syclobj=dpctl.SyclContext()) == "shared"
93102

94103

95104
@pytest.mark.skipif(not has_cpu(), reason="No SYCL CPU device available.")
96105
def test_memory_cpu_context():
97106
mobj = _create_memory()
98107

99-
# CPU context
100-
with dpctl.device_context("opencl:cpu:0"):
101-
# type respective to the context in which
102-
# memory was created
103-
usm_type = mobj.get_usm_type()
104-
assert usm_type == "shared"
108+
# type respective to the context in which
109+
# memory was created
110+
usm_type = mobj.get_usm_type()
111+
assert usm_type == "shared"
105112

106-
current_queue = dpctl.get_current_queue()
107-
# type as view from current queue
108-
usm_type = mobj.get_usm_type(current_queue)
109-
# type can be unknown if current queue is
110-
# not in the same SYCL context
111-
assert usm_type in ["unknown", "shared"]
113+
cpu_queue = dpctl.SyclQueue("cpu")
114+
# type as view from CPU queue
115+
usm_type = mobj.get_usm_type(cpu_queue)
116+
# type can be unknown if current queue is
117+
# not in the same SYCL context
118+
assert usm_type in ["unknown", "shared"]
112119

113120

114121
@pytest.mark.skipif(not has_gpu(), reason="No OpenCL GPU queues available")
115122
def test_memory_gpu_context():
116123
mobj = _create_memory()
117124

118125
# GPU context
119-
with dpctl.device_context("opencl:gpu:0"):
120-
usm_type = mobj.get_usm_type()
121-
assert usm_type == "shared"
122-
current_queue = dpctl.get_current_queue()
123-
usm_type = mobj.get_usm_type(current_queue)
124-
assert usm_type in ["unknown", "shared"]
126+
usm_type = mobj.get_usm_type()
127+
assert usm_type == "shared"
128+
gpu_queue = dpctl.SyclQueue("opencl:gpu")
129+
usm_type = mobj.get_usm_type(gpu_queue)
130+
assert usm_type in ["unknown", "shared"]
125131

126132

127133
@pytest.mark.skipif(
@@ -166,10 +172,10 @@ def test_zero_copy():
166172
not has_sycl_platforms(),
167173
reason="No SYCL devices except the default host device.",
168174
)
169-
def test_pickling():
175+
def test_pickling(memory_ctor):
170176
import pickle
171177

172-
mobj = _create_memory()
178+
mobj = memory_ctor(1024, alignment=64)
173179
host_src_obj = _create_host_buf(mobj.nbytes)
174180
mobj.copy_from_host(host_src_obj)
175181

@@ -185,6 +191,22 @@ def test_pickling():
185191
), "Pickling/unpickling should be changing pointer"
186192

187193

194+
@pytest.mark.skipif(
195+
not has_sycl_platforms(),
196+
reason="No SYCL devices except the default host device.",
197+
)
198+
def test_pickling_reconstructor_invalid_type(memory_ctor):
199+
import pickle
200+
201+
mobj = memory_ctor(1024, alignment=64)
202+
good_pickle_bytes = pickle.dumps(mobj)
203+
usm_types = expected_usm_type(memory_ctor).encode("utf-8")
204+
i = good_pickle_bytes.index(usm_types)
205+
bad_pickle_bytes = good_pickle_bytes[:i] + b"u" + good_pickle_bytes[i + 1 :]
206+
with pytest.raises(ValueError):
207+
pickle.loads(bad_pickle_bytes)
208+
209+
188210
@pytest.fixture(params=[MemoryUSMShared, MemoryUSMDevice, MemoryUSMHost])
189211
def memory_ctor(request):
190212
return request.param
@@ -389,3 +411,46 @@ def test_with_constructor(memory_ctor):
389411
syclobj=buf.sycl_device.filter_string,
390412
)
391413
check_view(v)
414+
415+
416+
@pytest.mark.skipif(
417+
not has_sycl_platforms(),
418+
reason="No SYCL devices except the default host device.",
419+
)
420+
def test_cpython_api(memory_ctor):
421+
import ctypes
422+
import sys
423+
424+
mobj = memory_ctor(1024)
425+
mod = sys.modules[mobj.__class__.__module__]
426+
# get capsules storing function pointers
427+
mem_ptr_fn_cap = mod.__pyx_capi__["get_usm_pointer"]
428+
mem_ctx_fn_cap = mod.__pyx_capi__["get_context"]
429+
mem_nby_fn_cap = mod.__pyx_capi__["get_nbytes"]
430+
# construct Python callable to invoke "get_usm_pointer"
431+
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
432+
cap_ptr_fn.restype = ctypes.c_void_p
433+
cap_ptr_fn.argtypes = [ctypes.py_object, ctypes.c_char_p]
434+
mem_ptr_fn_ptr = cap_ptr_fn(
435+
mem_ptr_fn_cap, b"DPCTLSyclUSMRef (struct Py_MemoryObject *)"
436+
)
437+
mem_ctx_fn_ptr = cap_ptr_fn(
438+
mem_ctx_fn_cap, b"DPCTLSyclContextRef (struct Py_MemoryObject *)"
439+
)
440+
mem_nby_fn_ptr = cap_ptr_fn(
441+
mem_nby_fn_cap, b"size_t (struct Py_MemoryObject *)"
442+
)
443+
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
444+
get_ptr_fn = callable_maker(mem_ptr_fn_ptr)
445+
get_ctx_fn = callable_maker(mem_ctx_fn_ptr)
446+
get_nby_fn = callable_maker(mem_nby_fn_ptr)
447+
448+
capi_ptr = get_ptr_fn(mobj)
449+
direct_ptr = mobj._pointer
450+
assert capi_ptr == direct_ptr
451+
capi_ctx_ref = get_ctx_fn(mobj)
452+
direct_ctx_ref = mobj._context.addressof_ref()
453+
assert capi_ctx_ref == direct_ctx_ref
454+
capi_nbytes = get_nby_fn(mobj)
455+
direct_nbytes = mobj.nbytes
456+
assert capi_nbytes == direct_nbytes

0 commit comments

Comments
 (0)