27
27
import dpctl .program as dpctl_prog
28
28
29
29
30
- def test_create_program_from_source ():
30
+ @pytest .mark .parametrize (
31
+ "ctype_str,dtype,ctypes_ctor" ,
32
+ [
33
+ ("short" , np .dtype ("i2" ), ctypes .c_short ),
34
+ ("int" , np .dtype ("i4" ), ctypes .c_int ),
35
+ ("unsigned int" , np .dtype ("u4" ), ctypes .c_uint ),
36
+ ("long" , np .dtype (np .longlong ), ctypes .c_long ),
37
+ ("unsigned long" , np .dtype (np .ulonglong ), ctypes .c_ulong ),
38
+ ("float" , np .dtype ("f4" ), ctypes .c_float ),
39
+ ("double" , np .dtype ("f8" ), ctypes .c_double ),
40
+ ],
41
+ )
42
+ def test_create_program_from_source (ctype_str , dtype , ctypes_ctor ):
31
43
try :
32
44
q = dpctl .SyclQueue ("opencl" , property = "enable_profiling" )
33
45
except dpctl .SyclQueueCreationError :
34
46
pytest .skip ("OpenCL queue could not be created" )
35
- oclSrc = " \
36
- kernel void axpy(global int* a, global int* b, global int* c, int d) { \
37
- size_t index = get_global_id(0); \
38
- c[index] = d*a[index] + b[index]; \
39
- }"
47
+ oclSrc = (
48
+ "kernel void axpy("
49
+ " global " + ctype_str + " *a, global " + ctype_str + " *b,"
50
+ " global " + ctype_str + " *c, " + ctype_str + " d) {"
51
+ " size_t index = get_global_id(0);"
52
+ " c[index] = d * a[index] + b[index];"
53
+ "}"
54
+ )
40
55
prog = dpctl_prog .create_program_from_source (q , oclSrc )
41
56
axpyKernel = prog .get_sycl_kernel ("axpy" )
42
57
43
58
n_elems = 1024 * 512
44
- bufBytes = n_elems * np . dtype ( "i" ) .itemsize
59
+ bufBytes = n_elems * dtype .itemsize
45
60
abuf = dpctl_mem .MemoryUSMShared (bufBytes , queue = q )
46
61
bbuf = dpctl_mem .MemoryUSMShared (bufBytes , queue = q )
47
62
cbuf = dpctl_mem .MemoryUSMShared (bufBytes , queue = q )
48
- a = np .ndarray ((n_elems ,), buffer = abuf , dtype = "i" )
49
- b = np .ndarray ((n_elems ,), buffer = bbuf , dtype = "i" )
50
- c = np .ndarray ((n_elems ,), buffer = cbuf , dtype = "i" )
63
+ a = np .ndarray ((n_elems ,), buffer = abuf , dtype = dtype )
64
+ b = np .ndarray ((n_elems ,), buffer = bbuf , dtype = dtype )
65
+ c = np .ndarray ((n_elems ,), buffer = cbuf , dtype = dtype )
51
66
a [:] = np .arange (n_elems )
52
67
b [:] = np .arange (n_elems , 0 , - 1 )
53
68
c [:] = 0
@@ -57,7 +72,7 @@ def test_create_program_from_source():
57
72
args .append (a .base )
58
73
args .append (b .base )
59
74
args .append (c .base )
60
- args .append (ctypes . c_int (d ))
75
+ args .append (ctypes_ctor (d ))
61
76
62
77
r = [
63
78
n_elems ,
@@ -66,7 +81,7 @@ def test_create_program_from_source():
66
81
timer = dpctl .SyclTimer ()
67
82
with timer (q ):
68
83
q .submit (axpyKernel , args , r )
69
- ref_c = a * d + b
84
+ ref_c = a * np . array ( d , dtype = dtype ) + b
70
85
host_dt , device_dt = timer .dt
71
86
assert host_dt > device_dt
72
87
assert np .allclose (c , ref_c )
0 commit comments