Skip to content

Parametrize ocl kernel compiled from source in types of arguments #581

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 63 additions & 25 deletions dpctl/tests/test_sycl_kernel_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,46 +27,84 @@
import dpctl.program as dpctl_prog


def test_create_program_from_source():
@pytest.mark.parametrize(
"ctype_str,dtype,ctypes_ctor",
[
("short", np.dtype("i2"), ctypes.c_short),
("int", np.dtype("i4"), ctypes.c_int),
("unsigned int", np.dtype("u4"), ctypes.c_uint),
("long", np.dtype(np.longlong), ctypes.c_longlong),
("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulonglong),
("float", np.dtype("f4"), ctypes.c_float),
("double", np.dtype("f8"), ctypes.c_double),
],
)
def test_create_program_from_source(ctype_str, dtype, ctypes_ctor):
try:
q = dpctl.SyclQueue("opencl", property="enable_profiling")
except dpctl.SyclQueueCreationError:
pytest.skip("OpenCL queue could not be created")
oclSrc = " \
kernel void axpy(global int* a, global int* b, global int* c, int d) { \
size_t index = get_global_id(0); \
c[index] = d*a[index] + b[index]; \
}"
# The OpenCL convention for indexing global_id is opposite to
# that of SYCL (and DPCTL)
oclSrc = (
"kernel void axpy("
" global " + ctype_str + " *a, global " + ctype_str + " *b,"
" global " + ctype_str + " *c, " + ctype_str + " d) {"
" size_t index = get_global_id(0);"
" c[index] = d * a[index] + b[index];"
"}"
)
prog = dpctl_prog.create_program_from_source(q, oclSrc)
axpyKernel = prog.get_sycl_kernel("axpy")

n_elems = 1024 * 512
bufBytes = n_elems * np.dtype("i").itemsize
lws = 128
bufBytes = n_elems * dtype.itemsize
abuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
bbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
cbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
a = np.ndarray((n_elems,), buffer=abuf, dtype="i")
b = np.ndarray((n_elems,), buffer=bbuf, dtype="i")
c = np.ndarray((n_elems,), buffer=cbuf, dtype="i")
a = np.ndarray((n_elems,), buffer=abuf, dtype=dtype)
b = np.ndarray((n_elems,), buffer=bbuf, dtype=dtype)
c = np.ndarray((n_elems,), buffer=cbuf, dtype=dtype)
a[:] = np.arange(n_elems)
b[:] = np.arange(n_elems, 0, -1)
c[:] = 0
d = 2
args = []
args = [a.base, b.base, c.base, ctypes_ctor(d)]

args.append(a.base)
args.append(b.base)
args.append(c.base)
args.append(ctypes.c_int(d))
assert n_elems % lws == 0

r = [
n_elems,
]
for r in (
[
n_elems,
],
[2, n_elems],
[2, 2, n_elems],
):
c[:] = 0
timer = dpctl.SyclTimer()
with timer(q):
q.submit(axpyKernel, args, r).wait()
ref_c = a * np.array(d, dtype=dtype) + b
host_dt, device_dt = timer.dt
assert host_dt > device_dt
assert np.allclose(c, ref_c), "Failed for {}".format(r)

timer = dpctl.SyclTimer()
with timer(q):
q.submit(axpyKernel, args, r)
ref_c = a * d + b
host_dt, device_dt = timer.dt
assert host_dt > device_dt
assert np.allclose(c, ref_c)
for gr, lr in (
(
[
n_elems,
],
[lws],
),
([2, n_elems], [2, lws // 2]),
([2, 2, n_elems], [2, 2, lws // 4]),
):
c[:] = 0
timer = dpctl.SyclTimer()
with timer(q):
q.submit(axpyKernel, args, gr, lr, [dpctl.SyclEvent()]).wait()
ref_c = a * np.array(d, dtype=dtype) + b
host_dt, device_dt = timer.dt
assert host_dt > device_dt
assert np.allclose(c, ref_c), "Failed for {}, {}".format(gr, lr)
1 change: 1 addition & 0 deletions dpctl/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def test_queue_capsule():
q2 = dpctl.SyclQueue(cap)
assert q == q2
del cap2 # call deleter on non-renamed capsule
assert q2 != [] # compare with other types


def test_cpython_api():
Expand Down