@@ -48,12 +48,20 @@ def __sycl_usm_array_interface(self):
48
48
not has_sycl_platforms (),
49
49
reason = "No SYCL devices except the default host device." ,
50
50
)
51
- def test_memory_create ():
51
+ def test_memory_create (memory_ctor ):
52
+ import sys
53
+
52
54
nbytes = 1024
53
- queue = dpctl .get_current_queue ()
54
- mobj = MemoryUSMShared (nbytes , alignment = 64 , queue = queue )
55
+ queue = dpctl .SyclQueue ()
56
+ mobj = memory_ctor (nbytes , alignment = 64 , queue = queue )
55
57
assert mobj .nbytes == nbytes
56
58
assert hasattr (mobj , "__sycl_usm_array_interface__" )
59
+ assert len (mobj ) == nbytes
60
+ assert mobj .size == nbytes
61
+ assert mobj ._context == queue .sycl_context
62
+ assert type (repr (mobj )) is str
63
+ assert type (bytes (mobj )) is bytes
64
+ assert sys .getsizeof (mobj ) > nbytes
57
65
58
66
59
67
@pytest .mark .skipif (
@@ -69,7 +77,7 @@ def test_memory_create_with_np():
69
77
70
78
def _create_memory ():
71
79
nbytes = 1024
72
- queue = dpctl .get_current_queue ()
80
+ queue = dpctl .SyclQueue ()
73
81
mobj = MemoryUSMShared (nbytes , alignment = 64 , queue = queue )
74
82
return mobj
75
83
@@ -90,38 +98,36 @@ def test_memory_without_context():
90
98
91
99
# Without context
92
100
assert mobj .get_usm_type () == "shared"
101
+ assert mobj .get_usm_type (syclobj = dpctl .SyclContext ()) == "shared"
93
102
94
103
95
104
@pytest .mark .skipif (not has_cpu (), reason = "No SYCL CPU device available." )
96
105
def test_memory_cpu_context ():
97
106
mobj = _create_memory ()
98
107
99
- # CPU context
100
- with dpctl .device_context ("opencl:cpu:0" ):
101
- # type respective to the context in which
102
- # memory was created
103
- usm_type = mobj .get_usm_type ()
104
- assert usm_type == "shared"
108
+ # type respective to the context in which
109
+ # memory was created
110
+ usm_type = mobj .get_usm_type ()
111
+ assert usm_type == "shared"
105
112
106
- current_queue = dpctl .get_current_queue ( )
107
- # type as view from current queue
108
- usm_type = mobj .get_usm_type (current_queue )
109
- # type can be unknown if current queue is
110
- # not in the same SYCL context
111
- assert usm_type in ["unknown" , "shared" ]
113
+ cpu_queue = dpctl .SyclQueue ( "cpu" )
114
+ # type as view from CPU queue
115
+ usm_type = mobj .get_usm_type (cpu_queue )
116
+ # type can be unknown if current queue is
117
+ # not in the same SYCL context
118
+ assert usm_type in ["unknown" , "shared" ]
112
119
113
120
114
121
@pytest .mark .skipif (not has_gpu (), reason = "No OpenCL GPU queues available" )
115
122
def test_memory_gpu_context ():
116
123
mobj = _create_memory ()
117
124
118
125
# GPU context
119
- with dpctl .device_context ("opencl:gpu:0" ):
120
- usm_type = mobj .get_usm_type ()
121
- assert usm_type == "shared"
122
- current_queue = dpctl .get_current_queue ()
123
- usm_type = mobj .get_usm_type (current_queue )
124
- assert usm_type in ["unknown" , "shared" ]
126
+ usm_type = mobj .get_usm_type ()
127
+ assert usm_type == "shared"
128
+ gpu_queue = dpctl .SyclQueue ("opencl:gpu" )
129
+ usm_type = mobj .get_usm_type (gpu_queue )
130
+ assert usm_type in ["unknown" , "shared" ]
125
131
126
132
127
133
@pytest .mark .skipif (
@@ -166,10 +172,10 @@ def test_zero_copy():
166
172
not has_sycl_platforms (),
167
173
reason = "No SYCL devices except the default host device." ,
168
174
)
169
- def test_pickling ():
175
+ def test_pickling (memory_ctor ):
170
176
import pickle
171
177
172
- mobj = _create_memory ( )
178
+ mobj = memory_ctor ( 1024 , alignment = 64 )
173
179
host_src_obj = _create_host_buf (mobj .nbytes )
174
180
mobj .copy_from_host (host_src_obj )
175
181
@@ -185,6 +191,22 @@ def test_pickling():
185
191
), "Pickling/unpickling should be changing pointer"
186
192
187
193
194
+ @pytest .mark .skipif (
195
+ not has_sycl_platforms (),
196
+ reason = "No SYCL devices except the default host device." ,
197
+ )
198
+ def test_pickling_reconstructor_invalid_type (memory_ctor ):
199
+ import pickle
200
+
201
+ mobj = memory_ctor (1024 , alignment = 64 )
202
+ good_pickle_bytes = pickle .dumps (mobj )
203
+ usm_types = expected_usm_type (memory_ctor ).encode ("utf-8" )
204
+ i = good_pickle_bytes .index (usm_types )
205
+ bad_pickle_bytes = good_pickle_bytes [:i ] + b"u" + good_pickle_bytes [i + 1 :]
206
+ with pytest .raises (ValueError ):
207
+ pickle .loads (bad_pickle_bytes )
208
+
209
+
188
210
@pytest .fixture (params = [MemoryUSMShared , MemoryUSMDevice , MemoryUSMHost ])
189
211
def memory_ctor (request ):
190
212
return request .param
@@ -389,3 +411,46 @@ def test_with_constructor(memory_ctor):
389
411
syclobj = buf .sycl_device .filter_string ,
390
412
)
391
413
check_view (v )
414
+
415
+
416
+ @pytest .mark .skipif (
417
+ not has_sycl_platforms (),
418
+ reason = "No SYCL devices except the default host device." ,
419
+ )
420
+ def test_cpython_api (memory_ctor ):
421
+ import ctypes
422
+ import sys
423
+
424
+ mobj = memory_ctor (1024 )
425
+ mod = sys .modules [mobj .__class__ .__module__ ]
426
+ # get capsules storing function pointers
427
+ mem_ptr_fn_cap = mod .__pyx_capi__ ["get_usm_pointer" ]
428
+ mem_ctx_fn_cap = mod .__pyx_capi__ ["get_context" ]
429
+ mem_nby_fn_cap = mod .__pyx_capi__ ["get_nbytes" ]
430
+ # construct Python callable to invoke "get_usm_pointer"
431
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
432
+ cap_ptr_fn .restype = ctypes .c_void_p
433
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
434
+ mem_ptr_fn_ptr = cap_ptr_fn (
435
+ mem_ptr_fn_cap , b"DPCTLSyclUSMRef (struct Py_MemoryObject *)"
436
+ )
437
+ mem_ctx_fn_ptr = cap_ptr_fn (
438
+ mem_ctx_fn_cap , b"DPCTLSyclContextRef (struct Py_MemoryObject *)"
439
+ )
440
+ mem_nby_fn_ptr = cap_ptr_fn (
441
+ mem_nby_fn_cap , b"size_t (struct Py_MemoryObject *)"
442
+ )
443
+ callable_maker = ctypes .PYFUNCTYPE (ctypes .c_void_p , ctypes .py_object )
444
+ get_ptr_fn = callable_maker (mem_ptr_fn_ptr )
445
+ get_ctx_fn = callable_maker (mem_ctx_fn_ptr )
446
+ get_nby_fn = callable_maker (mem_nby_fn_ptr )
447
+
448
+ capi_ptr = get_ptr_fn (mobj )
449
+ direct_ptr = mobj ._pointer
450
+ assert capi_ptr == direct_ptr
451
+ capi_ctx_ref = get_ctx_fn (mobj )
452
+ direct_ctx_ref = mobj ._context .addressof_ref ()
453
+ assert capi_ctx_ref == direct_ctx_ref
454
+ capi_nbytes = get_nby_fn (mobj )
455
+ direct_nbytes = mobj .nbytes
456
+ assert capi_nbytes == direct_nbytes
0 commit comments