Skip to content

Commit dd99abf

Browse files
Parametrize ocl kernel compiled from source in types of arguments
1 parent d785b69 commit dd99abf

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,42 @@
2727
import dpctl.program as dpctl_prog
2828

2929

30-
def test_create_program_from_source():
30+
@pytest.mark.parametrize(
31+
"ctype_str,dtype,ctypes_ctor",
32+
[
33+
("short", np.dtype("i2"), ctypes.c_short),
34+
("int", np.dtype("i4"), ctypes.c_int),
35+
("unsigned int", np.dtype("u4"), ctypes.c_uint),
36+
("long", np.dtype(np.longlong), ctypes.c_long),
37+
("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulong),
38+
("float", np.dtype("f4"), ctypes.c_float),
39+
("double", np.dtype("f8"), ctypes.c_double),
40+
],
41+
)
42+
def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
3143
try:
3244
q = dpctl.SyclQueue("opencl", property="enable_profiling")
3345
except dpctl.SyclQueueCreationError:
3446
pytest.skip("OpenCL queue could not be created")
35-
oclSrc = " \
36-
kernel void axpy(global int* a, global int* b, global int* c, int d) { \
37-
size_t index = get_global_id(0); \
38-
c[index] = d*a[index] + b[index]; \
39-
}"
47+
oclSrc = (
48+
"kernel void axpy("
49+
" global " + ctype_str + " *a, global " + ctype_str + " *b,"
50+
" global " + ctype_str + " *c, " + ctype_str + " d) {"
51+
" size_t index = get_global_id(0);"
52+
" c[index] = d * a[index] + b[index];"
53+
"}"
54+
)
4055
prog = dpctl_prog.create_program_from_source(q, oclSrc)
4156
axpyKernel = prog.get_sycl_kernel("axpy")
4257

4358
n_elems = 1024 * 512
44-
bufBytes = n_elems * np.dtype("i").itemsize
59+
bufBytes = n_elems * dtype.itemsize
4560
abuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
4661
bbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
4762
cbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
48-
a = np.ndarray((n_elems,), buffer=abuf, dtype="i")
49-
b = np.ndarray((n_elems,), buffer=bbuf, dtype="i")
50-
c = np.ndarray((n_elems,), buffer=cbuf, dtype="i")
63+
a = np.ndarray((n_elems,), buffer=abuf, dtype=dtype)
64+
b = np.ndarray((n_elems,), buffer=bbuf, dtype=dtype)
65+
c = np.ndarray((n_elems,), buffer=cbuf, dtype=dtype)
5166
a[:] = np.arange(n_elems)
5267
b[:] = np.arange(n_elems, 0, -1)
5368
c[:] = 0
@@ -57,7 +72,7 @@ def test_create_program_from_source():
5772
args.append(a.base)
5873
args.append(b.base)
5974
args.append(c.base)
60-
args.append(ctypes.c_int(d))
75+
args.append(ctypes_ctor(d))
6176

6277
r = [
6378
n_elems,
@@ -66,7 +81,7 @@ def test_create_program_from_source():
6681
timer = dpctl.SyclTimer()
6782
with timer(q):
6883
q.submit(axpyKernel, args, r)
69-
ref_c = a * d + b
84+
ref_c = a * np.array(d, dtype=dtype) + b
7085
host_dt, device_dt = timer.dt
7186
assert host_dt > device_dt
7287
assert np.allclose(c, ref_c)

0 commit comments

Comments
 (0)