@@ -373,3 +373,75 @@ def test_datapi_device():
373
373
X .device .sycl_queue
374
374
X .device .sycl_device
375
375
repr (X .device )
376
+
377
+
378
+ def test_pyx_capi ():
379
+ import ctypes
380
+ import sys
381
+
382
+ X = dpt .usm_ndarray (17 )[1 ::2 ]
383
+ mod = sys .modules [X .__class__ .__module__ ]
384
+ # get capsule storign get_context_ref function ptr
385
+ arr_data_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_data" ]
386
+ arr_ndim_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_ndim" ]
387
+ arr_shape_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_shape" ]
388
+ arr_strides_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_strides" ]
389
+ arr_typenum_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_typenum" ]
390
+ arr_flags_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_flags" ]
391
+ arr_queue_ref_fn_cap = mod .__pyx_capi__ ["usm_ndarray_get_queue_ref" ]
392
+ # construct Python callable to invoke these functions
393
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
394
+ cap_ptr_fn .restype = ctypes .c_void_p
395
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
396
+ callable_maker_ptr = ctypes .PYFUNCTYPE (ctypes .c_void_p , ctypes .py_object )
397
+ callable_maker_int = ctypes .PYFUNCTYPE (ctypes .c_int , ctypes .py_object )
398
+ arr_data_fn_ptr = cap_ptr_fn (
399
+ arr_data_fn_cap , b"char *(struct PyUSMArrayObject *)"
400
+ )
401
+ get_data_fn = callable_maker_ptr (arr_data_fn_ptr )
402
+
403
+ arr_ndim_fn_ptr = cap_ptr_fn (
404
+ arr_ndim_fn_cap , b"int (struct PyUSMArrayObject *)"
405
+ )
406
+ get_ndim_fn = callable_maker_int (arr_ndim_fn_ptr )
407
+
408
+ arr_shape_fn_ptr = cap_ptr_fn (
409
+ arr_shape_fn_cap , b"Py_ssize_t *(struct PyUSMArrayObject *)"
410
+ )
411
+ get_shape_fn = callable_maker_ptr (arr_shape_fn_ptr )
412
+
413
+ arr_strides_fn_ptr = cap_ptr_fn (
414
+ arr_strides_fn_cap , b"Py_ssize_t *(struct PyUSMArrayObject *)"
415
+ )
416
+ get_strides_fn = callable_maker_ptr (arr_strides_fn_ptr )
417
+ arr_typenum_fn_ptr = cap_ptr_fn (
418
+ arr_typenum_fn_cap , b"int (struct PyUSMArrayObject *)"
419
+ )
420
+ get_typenum_fn = callable_maker_int (arr_typenum_fn_ptr )
421
+ arr_flags_fn_ptr = cap_ptr_fn (
422
+ arr_flags_fn_cap , b"int (struct PyUSMArrayObject *)"
423
+ )
424
+ get_flags_fn = callable_maker_int (arr_flags_fn_ptr )
425
+ arr_queue_ref_fn_ptr = cap_ptr_fn (
426
+ arr_queue_ref_fn_cap , b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)"
427
+ )
428
+ get_queue_ref_fn = callable_maker_ptr (arr_queue_ref_fn_ptr )
429
+
430
+ r1 = get_data_fn (X )
431
+ sua_iface = X .__sycl_usm_array_interface__
432
+ assert r1 == sua_iface ["data" ][0 ] + sua_iface .get ("offset" ) * X .itemsize
433
+ assert get_ndim_fn (X ) == X .ndim
434
+ c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
435
+ shape0 = ctypes .cast (get_shape_fn (X ), c_longlong_p ).contents .value
436
+ assert shape0 == X .shape [0 ]
437
+ strides0_p = get_strides_fn (X )
438
+ if strides0_p :
439
+ strides0_p = ctypes .cast (strides0_p , c_longlong_p ).contents
440
+ strides0_p = strides0_p .value
441
+ assert strides0_p == 0 or strides0_p == X .strides [0 ]
442
+ typenum = get_typenum_fn (X )
443
+ assert type (typenum ) is int
444
+ flags = get_flags_fn (X )
445
+ assert type (flags ) is int and flags == X .flags
446
+ queue_ref = get_queue_ref_fn (X ) # address of a copy, should be unequal
447
+ assert queue_ref != X .sycl_queue .addressof_ref ()
0 commit comments