|
27 | 27 | import dpctl.program as dpctl_prog
|
28 | 28 |
|
29 | 29 |
|
30 |
| -def test_create_program_from_source(): |
| 30 | +@pytest.mark.parametrize( |
| 31 | + "ctype_str,dtype,ctypes_ctor", |
| 32 | + [ |
| 33 | + ("short", np.dtype("i2"), ctypes.c_short), |
| 34 | + ("int", np.dtype("i4"), ctypes.c_int), |
| 35 | + ("unsigned int", np.dtype("u4"), ctypes.c_uint), |
| 36 | + ("long", np.dtype(np.longlong), ctypes.c_longlong), |
| 37 | + ("unsigned long", np.dtype(np.ulonglong), ctypes.c_ulonglong), |
| 38 | + ("float", np.dtype("f4"), ctypes.c_float), |
| 39 | + ("double", np.dtype("f8"), ctypes.c_double), |
| 40 | + ], |
| 41 | +) |
| 42 | +def test_create_program_from_source(ctype_str, dtype, ctypes_ctor): |
31 | 43 | try:
|
32 | 44 | q = dpctl.SyclQueue("opencl", property="enable_profiling")
|
33 | 45 | except dpctl.SyclQueueCreationError:
|
34 | 46 | pytest.skip("OpenCL queue could not be created")
|
35 |
| - oclSrc = " \ |
36 |
| - kernel void axpy(global int* a, global int* b, global int* c, int d) { \ |
37 |
| - size_t index = get_global_id(0); \ |
38 |
| - c[index] = d*a[index] + b[index]; \ |
39 |
| - }" |
| 47 | + # OpenCL conventions for indexing global_id is opposite to |
| 48 | + # that of SYCL (and DPCTL) |
| 49 | + oclSrc = ( |
| 50 | + "kernel void axpy(" |
| 51 | + " global " + ctype_str + " *a, global " + ctype_str + " *b," |
| 52 | + " global " + ctype_str + " *c, " + ctype_str + " d) {" |
| 53 | + " size_t index = get_global_id(0);" |
| 54 | + " c[index] = d * a[index] + b[index];" |
| 55 | + "}" |
| 56 | + ) |
40 | 57 | prog = dpctl_prog.create_program_from_source(q, oclSrc)
|
41 | 58 | axpyKernel = prog.get_sycl_kernel("axpy")
|
42 | 59 |
|
43 | 60 | n_elems = 1024 * 512
|
44 |
| - bufBytes = n_elems * np.dtype("i").itemsize |
| 61 | + lws = 128 |
| 62 | + bufBytes = n_elems * dtype.itemsize |
45 | 63 | abuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
|
46 | 64 | bbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
|
47 | 65 | cbuf = dpctl_mem.MemoryUSMShared(bufBytes, queue=q)
|
48 |
| - a = np.ndarray((n_elems,), buffer=abuf, dtype="i") |
49 |
| - b = np.ndarray((n_elems,), buffer=bbuf, dtype="i") |
50 |
| - c = np.ndarray((n_elems,), buffer=cbuf, dtype="i") |
| 66 | + a = np.ndarray((n_elems,), buffer=abuf, dtype=dtype) |
| 67 | + b = np.ndarray((n_elems,), buffer=bbuf, dtype=dtype) |
| 68 | + c = np.ndarray((n_elems,), buffer=cbuf, dtype=dtype) |
51 | 69 | a[:] = np.arange(n_elems)
|
52 | 70 | b[:] = np.arange(n_elems, 0, -1)
|
53 | 71 | c[:] = 0
|
54 | 72 | d = 2
|
55 |
| - args = [] |
| 73 | + args = [a.base, b.base, c.base, ctypes_ctor(d)] |
56 | 74 |
|
57 |
| - args.append(a.base) |
58 |
| - args.append(b.base) |
59 |
| - args.append(c.base) |
60 |
| - args.append(ctypes.c_int(d)) |
| 75 | + assert n_elems % lws == 0 |
61 | 76 |
|
62 |
| - r = [ |
63 |
| - n_elems, |
64 |
| - ] |
| 77 | + for r in ( |
| 78 | + [ |
| 79 | + n_elems, |
| 80 | + ], |
| 81 | + [2, n_elems], |
| 82 | + [2, 2, n_elems], |
| 83 | + ): |
| 84 | + c[:] = 0 |
| 85 | + timer = dpctl.SyclTimer() |
| 86 | + with timer(q): |
| 87 | + q.submit(axpyKernel, args, r).wait() |
| 88 | + ref_c = a * np.array(d, dtype=dtype) + b |
| 89 | + host_dt, device_dt = timer.dt |
| 90 | + assert host_dt > device_dt |
| 91 | + assert np.allclose(c, ref_c), "Failed for {}".format(r) |
65 | 92 |
|
66 |
| - timer = dpctl.SyclTimer() |
67 |
| - with timer(q): |
68 |
| - q.submit(axpyKernel, args, r) |
69 |
| - ref_c = a * d + b |
70 |
| - host_dt, device_dt = timer.dt |
71 |
| - assert host_dt > device_dt |
72 |
| - assert np.allclose(c, ref_c) |
| 93 | + for gr, lr in ( |
| 94 | + ( |
| 95 | + [ |
| 96 | + n_elems, |
| 97 | + ], |
| 98 | + [lws], |
| 99 | + ), |
| 100 | + ([2, n_elems], [2, lws // 2]), |
| 101 | + ([2, 2, n_elems], [2, 2, lws // 4]), |
| 102 | + ): |
| 103 | + c[:] = 0 |
| 104 | + timer = dpctl.SyclTimer() |
| 105 | + with timer(q): |
| 106 | + q.submit(axpyKernel, args, gr, lr, [dpctl.SyclEvent()]).wait() |
| 107 | + ref_c = a * np.array(d, dtype=dtype) + b |
| 108 | + host_dt, device_dt = timer.dt |
| 109 | + assert host_dt > device_dt |
| 110 | + assert np.allclose(c, ref_c), "Faled for {}, {}".formatg(r, lr) |
0 commit comments