Skip to content

Commit 50a3243

Browse files
Added few assertions, moved creation of common buffers to the top to enable commenting out sections of the code for performance measurements
1 parent 93e50a8 commit 50a3243

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

examples/pybind11/onemkl_gemv/sycl_timing_solver.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@
5555
A = dpt.asarray(Anp, "d", device=api_dev)
5656
b = dpt.asarray(bnp, "d", device=api_dev)
5757

58+
assert A.sycl_queue == b.sycl_queue
59+
60+
# allocate buffers for computation of residual
61+
r = dpt.empty_like(b)
62+
delta = dpt.empty_like(b)
63+
5864
timer = dpctl.SyclTimer(time_scale=1e3)
5965

6066
iters = []
@@ -64,17 +70,22 @@
6470

6571
print(i, "(host_dt, device_dt)=", timer.dt)
6672
iters.append(conv_in)
73+
assert x.usm_type == A.usm_type
74+
assert x.usm_type == b.usm_type
75+
assert x.sycl_queue == A.sycl_queue
76+
assert x.sycl_queue == b.sycl_queue
6777

6878
print("Converged in: ", iters)
6979

70-
r = dpt.empty_like(b)
7180
hev, ev = sycl_gemm.gemv(q, A, x, r)
72-
delta = dpt.empty_like(b)
7381
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
7482
rs = sycl_gemm.norm_squared_blocking(q, delta)
7583
dpctl.SyclEvent.wait_for([hev, hev2])
7684
print(f"Python solution residual norm squared: {rs}")
7785

86+
assert q == api_dev.sycl_queue
87+
print("")
88+
7889
x_cpp = dpt.empty_like(b)
7990
iters = []
8091
for i in range(6):

0 commit comments

Comments
 (0)