Skip to content

Commit 50be964

Browse files
Modularized the examples/pybind11/onemkl_gemv/sycl_timing_solver.py (#838)
1 parent f94ef5e commit 50be964

File tree

1 file changed

+59
-34
lines changed

1 file changed

+59
-34
lines changed

examples/pybind11/onemkl_gemv/sycl_timing_solver.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -63,41 +63,66 @@
6363

6464
timer = dpctl.SyclTimer(time_scale=1e3)
6565

66-
iters = []
67-
for i in range(6):
68-
with timer(api_dev.sycl_queue):
69-
x, conv_in = solve.cg_solve(A, b)
70-
71-
print(i, "(host_dt, device_dt)=", timer.dt)
72-
iters.append(conv_in)
73-
assert x.usm_type == A.usm_type
74-
assert x.usm_type == b.usm_type
75-
assert x.sycl_queue == A.sycl_queue
76-
assert x.sycl_queue == b.sycl_queue
77-
78-
print("Converged in: ", iters)
79-
80-
hev, ev = sycl_gemm.gemv(q, A, x, r)
81-
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
82-
rs = sycl_gemm.norm_squared_blocking(q, delta)
83-
dpctl.SyclEvent.wait_for([hev, hev2])
84-
print(f"Python solution residual norm squared: {rs}")
66+
67+
def time_python_solver(num_iters=6):
68+
"""
69+
Time solver implemented in Python with use of asynchronous
70+
SYCL kernel submission.
71+
"""
72+
global x
73+
iters = []
74+
for i in range(num_iters):
75+
with timer(api_dev.sycl_queue):
76+
x, conv_in = solve.cg_solve(A, b)
77+
78+
print(i, "(host_dt, device_dt)=", timer.dt)
79+
iters.append(conv_in)
80+
assert x.usm_type == A.usm_type
81+
assert x.usm_type == b.usm_type
82+
assert x.sycl_queue == A.sycl_queue
83+
assert x.sycl_queue == b.sycl_queue
84+
85+
return iters
86+
87+
88+
def time_cpp_solver(num_iters=6):
89+
"""
90+
Time solver implemented in C++ but callable from Python.
91+
The C++ implementation uses the same algorithm and submits the same
92+
kernels asynchronously, but bypasses Python binding overhead
93+
incurred when the algorithm is driven from Python.
94+
"""
95+
global x_cpp
96+
x_cpp = dpt.empty_like(b)
97+
iters = []
98+
for i in range(num_iters):
99+
with timer(api_dev.sycl_queue):
100+
conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp)
101+
102+
print(i, "(host_dt, device_dt)=", timer.dt)
103+
iters.append(conv_in)
104+
105+
return iters
106+
107+
108+
def compute_residual(x):
109+
"""
110+
Computes quality of the solution, `norm_squared(A@x - b)`.
111+
"""
112+
assert isinstance(x, dpt.usm_ndarray)
113+
q = A.sycl_queue
114+
hev, ev = sycl_gemm.gemv(q, A, x, r)
115+
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
116+
rs = sycl_gemm.norm_squared_blocking(q, delta)
117+
dpctl.SyclEvent.wait_for([hev, hev2])
118+
return rs
119+
120+
121+
print("Converged in: ", time_python_solver())
122+
print(f"Python solution residual norm squared: {compute_residual(x)}")
85123

86124
assert q == api_dev.sycl_queue
87125
print("")
88126

89-
x_cpp = dpt.empty_like(b)
90-
iters = []
91-
for i in range(6):
92-
with timer(api_dev.sycl_queue):
93-
conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp)
94-
95-
print(i, "(host_dt, device_dt)=", timer.dt)
96-
iters.append(conv_in)
97-
98-
print("Converged in: ", iters)
99-
hev, ev = sycl_gemm.gemv(q, A, x_cpp, r)
100-
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
101-
rs = sycl_gemm.norm_squared_blocking(q, delta)
102-
dpctl.SyclEvent.wait_for([hev, hev2])
103-
print(f"cpp_cg_solve solution residual norm squared: {rs}")
127+
print("Converged in: ", time_cpp_solver())
128+
print(f"cpp_cg_solve solution residual norm squared: {compute_residual(x_cpp)}")

0 commit comments

Comments
 (0)