@@ -373,7 +373,11 @@ def test_datapi_device():
373
373
374
374
375
375
def _pyx_capi_fnptr_to_callable (
376
- X , pyx_capi_name , caps_name , fn_restype = ctypes .c_void_p
376
+ X ,
377
+ pyx_capi_name ,
378
+ caps_name ,
379
+ fn_restype = ctypes .c_void_p ,
380
+ fn_argtypes = (ctypes .py_object ,),
377
381
):
378
382
import sys
379
383
@@ -388,7 +392,7 @@ def _pyx_capi_fnptr_to_callable(
388
392
cap_ptr_fn .restype = ctypes .c_void_p
389
393
cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
390
394
fn_ptr = cap_ptr_fn (cap , caps_name )
391
- callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , ctypes . py_object )
395
+ callable_maker_ptr = ctypes .PYFUNCTYPE (fn_restype , * fn_argtypes )
392
396
return callable_maker_ptr (fn_ptr )
393
397
394
398
@@ -399,6 +403,7 @@ def test_pyx_capi_get_data():
399
403
"UsmNDArray_GetData" ,
400
404
b"char *(struct PyUSMArrayObject *)" ,
401
405
fn_restype = ctypes .c_void_p ,
406
+ fn_argtypes = (ctypes .py_object ,),
402
407
)
403
408
r1 = get_data_fn (X )
404
409
sua_iface = X .__sycl_usm_array_interface__
@@ -412,6 +417,7 @@ def test_pyx_capi_get_shape():
412
417
"UsmNDArray_GetShape" ,
413
418
b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
414
419
fn_restype = ctypes .c_void_p ,
420
+ fn_argtypes = (ctypes .py_object ,),
415
421
)
416
422
c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
417
423
shape0 = ctypes .cast (get_shape_fn (X ), c_longlong_p ).contents .value
@@ -425,6 +431,7 @@ def test_pyx_capi_get_strides():
425
431
"UsmNDArray_GetStrides" ,
426
432
b"Py_ssize_t *(struct PyUSMArrayObject *)" ,
427
433
fn_restype = ctypes .c_void_p ,
434
+ fn_argtypes = (ctypes .py_object ,),
428
435
)
429
436
c_longlong_p = ctypes .POINTER (ctypes .c_longlong )
430
437
strides0_p = get_strides_fn (X )
@@ -441,6 +448,7 @@ def test_pyx_capi_get_ndim():
441
448
"UsmNDArray_GetNDim" ,
442
449
b"int (struct PyUSMArrayObject *)" ,
443
450
fn_restype = ctypes .c_int ,
451
+ fn_argtypes = (ctypes .py_object ,),
444
452
)
445
453
assert get_ndim_fn (X ) == X .ndim
446
454
@@ -452,6 +460,7 @@ def test_pyx_capi_get_typenum():
452
460
"UsmNDArray_GetTypenum" ,
453
461
b"int (struct PyUSMArrayObject *)" ,
454
462
fn_restype = ctypes .c_int ,
463
+ fn_argtypes = (ctypes .py_object ,),
455
464
)
456
465
typenum = get_typenum_fn (X )
457
466
assert type (typenum ) is int
@@ -465,6 +474,7 @@ def test_pyx_capi_get_elemsize():
465
474
"UsmNDArray_GetElementSize" ,
466
475
b"int (struct PyUSMArrayObject *)" ,
467
476
fn_restype = ctypes .c_int ,
477
+ fn_argtypes = (ctypes .py_object ,),
468
478
)
469
479
itemsize = get_elemsize_fn (X )
470
480
assert type (itemsize ) is int
@@ -478,6 +488,7 @@ def test_pyx_capi_get_flags():
478
488
"UsmNDArray_GetFlags" ,
479
489
b"int (struct PyUSMArrayObject *)" ,
480
490
fn_restype = ctypes .c_int ,
491
+ fn_argtypes = (ctypes .py_object ,),
481
492
)
482
493
flags = get_flags_fn (X )
483
494
assert type (flags ) is int and X .flags == flags
@@ -490,6 +501,7 @@ def test_pyx_capi_get_offset():
490
501
"UsmNDArray_GetOffset" ,
491
502
b"Py_ssize_t (struct PyUSMArrayObject *)" ,
492
503
fn_restype = ctypes .c_longlong ,
504
+ fn_argtypes = (ctypes .py_object ,),
493
505
)
494
506
offset = get_offset_fn (X )
495
507
assert type (offset ) is int
@@ -503,11 +515,104 @@ def test_pyx_capi_get_queue_ref():
503
515
"UsmNDArray_GetQueueRef" ,
504
516
b"DPCTLSyclQueueRef (struct PyUSMArrayObject *)" ,
505
517
fn_restype = ctypes .c_void_p ,
518
+ fn_argtypes = (ctypes .py_object ,),
506
519
)
507
520
queue_ref = get_queue_ref_fn (X ) # address of a copy, should be unequal
508
521
assert queue_ref != X .sycl_queue .addressof_ref ()
509
522
510
523
524
+ def test_pyx_capi_make_from_memory ():
525
+ q = get_queue_or_skip ()
526
+ n0 , n1 = 4 , 6
527
+ c_tuple = (ctypes .c_ssize_t * 2 )(n0 , n1 )
528
+ mem = dpm .MemoryUSMShared (n0 * n1 * 4 , queue = q )
529
+ typenum = dpt .dtype ("single" ).num
530
+ any_usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
531
+ make_from_memory_fn = _pyx_capi_fnptr_to_callable (
532
+ any_usm_ndarray ,
533
+ "UsmNDArray_MakeFromMemory" ,
534
+ b"PyObject *(int, Py_ssize_t const *, int, "
535
+ b"struct Py_MemoryObject *, Py_ssize_t)" ,
536
+ fn_restype = ctypes .py_object ,
537
+ fn_argtypes = (
538
+ ctypes .c_int ,
539
+ ctypes .POINTER (ctypes .c_ssize_t ),
540
+ ctypes .c_int ,
541
+ ctypes .py_object ,
542
+ ctypes .c_ssize_t ,
543
+ ),
544
+ )
545
+ r = make_from_memory_fn (
546
+ ctypes .c_int (2 ),
547
+ c_tuple ,
548
+ ctypes .c_int (typenum ),
549
+ mem ,
550
+ ctypes .c_ssize_t (0 ),
551
+ )
552
+ assert isinstance (r , dpt .usm_ndarray )
553
+ assert r .ndim == 2
554
+ assert r .shape == (n0 , n1 )
555
+ assert r ._pointer == mem ._pointer
556
+ assert r .usm_type == "shared"
557
+ assert r .sycl_queue == q
558
+
559
+
560
+ def test_pyx_capi_set_writable_flag ():
561
+ q = get_queue_or_skip ()
562
+ usm_ndarray = dpt .empty ((4 , 5 ), dtype = "i4" , sycl_queue = q )
563
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
564
+ assert usm_ndarray .flags ["WRITABLE" ] is True
565
+ set_writable = _pyx_capi_fnptr_to_callable (
566
+ usm_ndarray ,
567
+ "UsmNDArray_SetWritableFlag" ,
568
+ b"void (struct PyUSMArrayObject *, int)" ,
569
+ fn_restype = None ,
570
+ fn_argtypes = (ctypes .py_object , ctypes .c_int ),
571
+ )
572
+ set_writable (usm_ndarray , ctypes .c_int (0 ))
573
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
574
+ assert usm_ndarray .flags ["WRITABLE" ] is False
575
+ set_writable (usm_ndarray , ctypes .c_int (1 ))
576
+ assert isinstance (usm_ndarray , dpt .usm_ndarray )
577
+ assert usm_ndarray .flags ["WRITABLE" ] is True
578
+
579
+
580
+ def test_pyx_capi_make_from_ptr ():
581
+ q = get_queue_or_skip ()
582
+ usm_ndarray = dpt .empty (tuple (), dtype = "i4" , sycl_queue = q )
583
+ make_from_ptr = _pyx_capi_fnptr_to_callable (
584
+ usm_ndarray ,
585
+ "UsmNDArray_MakeFromPtr" ,
586
+ b"PyObject *(size_t, int, DPCTLSyclUSMRef, "
587
+ b"DPCTLSyclQueueRef, PyObject *)" ,
588
+ fn_restype = ctypes .py_object ,
589
+ fn_argtypes = (
590
+ ctypes .c_size_t ,
591
+ ctypes .c_int ,
592
+ ctypes .c_void_p ,
593
+ ctypes .c_void_p ,
594
+ ctypes .py_object ,
595
+ ),
596
+ )
597
+ nelems = 10
598
+ dt = dpt .int64
599
+ mem = dpm .MemoryUSMDevice (nelems * dt .itemsize , queue = q )
600
+ arr = make_from_ptr (
601
+ ctypes .c_size_t (nelems ),
602
+ dt .num ,
603
+ mem ._pointer ,
604
+ mem .sycl_queue .addressof_ref (),
605
+ mem ,
606
+ )
607
+ assert isinstance (arr , dpt .usm_ndarray )
608
+ assert arr .shape == (nelems ,)
609
+ assert arr .dtype == dt
610
+ assert arr .sycl_queue == q
611
+ assert arr ._pointer == mem ._pointer
612
+ del mem
613
+ assert isinstance (arr .__repr__ (), str )
614
+
615
+
511
616
def _pyx_capi_int (X , pyx_capi_name , caps_name = b"int" , val_restype = ctypes .c_int ):
512
617
import sys
513
618
0 commit comments