Skip to content

Commit 5f155bc

Browse files
Merge pull request #581 from IntelPython/parametrize-ocl-kernel-arg-types
Parametrize ocl kernel compiled from source in types of arguments
2 parents d785b69 + 4e5765a commit 5f155bc

File tree

2 files changed

+64
-25
lines changed

2 files changed

+64
-25
lines changed

dpctl/tests/test_sycl_kernel_submit.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,46 +27,84 @@
2727
import dpctl.program as dpctl_prog
2828

2929

30-
def test_create_program_from_source():
30+
@pytest.mark.parametrize(
31+
"ctype_str,dtype,ctypes_ctor",
32+
[
33+
("short", np.dtype("i2"), ctypes.c_short),
34+
("int", np.dtype("i4"), ctypes.c_int),
35+
("unsigned int", np.dtype("u4"), ctypes.c_uint),
36+
("long", np.dtype(np.longlong), ctypes.c_longlong),
37+
("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulonglong),
38+
("float", np.dtype("f4"), ctypes.c_float),
39+
("double", np.dtype("f8"), ctypes.c_double),
40+
],
41+
)
42+
def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
3143
try:
3244
q = dpctl.SyclQueue("opencl", property="enable_profiling")
3345
except dpctl.SyclQueueCreationError:
3446
pytest.skip("OpenCL queue could not be created")
35-
oclSrc = " \
36-
kernel void axpy(global int* a, global int* b, global int* c, int d) { \
37-
size_t index = get_global_id(0); \
38-
c[index] = d*a[index] + b[index]; \
39-
}"
47+
# OpenCL convention for indexing global_id is opposite to
48+
# that of SYCL (and DPCTL)
49+
oclSrc = (
50+
"kernel void axpy("
51+
" global " + ctype_str + " *a, global " + ctype_str + " *b,"
52+
" global " + ctype_str + " *c, " + ctype_str + " d) {"
53+
" size_t index = get_global_id(0);"
54+
" c[index] = d * a[index] + b[index];"
55+
"}"
56+
)
4057
prog = dpctl_prog.create_program_from_source(q, oclSrc)
4158
axpyKernel = prog.get_sycl_kernel("axpy")
4259

4360
n_elems = 1024 * 512
44-
bufBytes = n_elems * np.dtype("i").itemsize
61+
lws = 128
62+
bufBytes = n_elems * dtype.itemsize
4563
abuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
4664
bbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
4765
cbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
48-
a = np.ndarray((n_elems,), buffer=abuf, dtype="i")
49-
b = np.ndarray((n_elems,), buffer=bbuf, dtype="i")
50-
c = np.ndarray((n_elems,), buffer=cbuf, dtype="i")
66+
a = np.ndarray((n_elems,), buffer=abuf, dtype=dtype)
67+
b = np.ndarray((n_elems,), buffer=bbuf, dtype=dtype)
68+
c = np.ndarray((n_elems,), buffer=cbuf, dtype=dtype)
5169
a[:] = np.arange(n_elems)
5270
b[:] = np.arange(n_elems, 0, -1)
5371
c[:] = 0
5472
d = 2
55-
args = []
73+
args = [a.base, b.base, c.base, ctypes_ctor(d)]
5674

57-
args.append(a.base)
58-
args.append(b.base)
59-
args.append(c.base)
60-
args.append(ctypes.c_int(d))
75+
assert n_elems % lws == 0
6176

62-
r = [
63-
n_elems,
64-
]
77+
for r in (
78+
[
79+
n_elems,
80+
],
81+
[2, n_elems],
82+
[2, 2, n_elems],
83+
):
84+
c[:] = 0
85+
timer = dpctl.SyclTimer()
86+
with timer(q):
87+
q.submit(axpyKernel, args, r).wait()
88+
ref_c = a * np.array(d, dtype=dtype) + b
89+
host_dt, device_dt = timer.dt
90+
assert host_dt > device_dt
91+
assert np.allclose(c, ref_c), "Failed for {}".format(r)
6592

66-
timer = dpctl.SyclTimer()
67-
with timer(q):
68-
q.submit(axpyKernel, args, r)
69-
ref_c = a * d + b
70-
host_dt, device_dt = timer.dt
71-
assert host_dt > device_dt
72-
assert np.allclose(c, ref_c)
93+
for gr, lr in (
94+
(
95+
[
96+
n_elems,
97+
],
98+
[lws],
99+
),
100+
([2, n_elems], [2, lws // 2]),
101+
([2, 2, n_elems], [2, 2, lws // 4]),
102+
):
103+
c[:] = 0
104+
timer = dpctl.SyclTimer()
105+
with timer(q):
106+
q.submit(axpyKernel, args, gr, lr, [dpctl.SyclEvent()]).wait()
107+
ref_c = a * np.array(d, dtype=dtype) + b
108+
host_dt, device_dt = timer.dt
109+
assert host_dt > device_dt
110+
assert np.allclose(c, ref_c), "Failed for {}, {}".format(gr, lr)

dpctl/tests/test_sycl_queue.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def test_queue_capsule():
464464
q2 = dpctl.SyclQueue(cap)
465465
assert q == q2
466466
del cap2 # call deleter on non-renamed capsule
467+
assert q2 != [] # compare with other types
467468

468469

469470
def test_cpython_api():

0 commit comments

Comments
 (0)