Skip to content

Commit d1ec4d0

Browse files
Align task dependency with what C++ does
Use dpctl.tensor.empty_like, zeros_like instead of custom-defined function.
1 parent 0141fd5 commit d1ec4d0

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

examples/pybind11/onemkl_gemv/solve.py

Lines changed: 13 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,6 @@
2121
import dpctl.tensor as dpt
2222

2323

24-
def empty_like(A):
25-
return dpt.empty(A.shape, A.dtype, device=A.device)
26-
27-
2824
def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
2925
"""Chebyshev iterative solver using SYCL routines"""
3026
d = (lMax + lMin) / 2
@@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
3329
x = dpt.copy(x0)
3430
exec_queue = A.sycl_queue
3531
assert exec_queue == x.sycl_queue
36-
Ax = empty_like(A[:, 0])
37-
r = empty_like(Ax)
38-
p = empty_like(Ax)
32+
Ax = dpt.empty_like(A[:, 0])
33+
r = dpt.empty_like(Ax)
34+
p = dpt.empty_like(Ax)
3935

4036
e_x = dpctl.SyclEvent()
4137
# Ax = A @ x
@@ -131,12 +127,13 @@ def cg_solve(A, b):
131127
converged is False if solver has not converged, or the iteration number
132128
"""
133129
exec_queue = A.sycl_queue
134-
x = dpt.zeros(b.shape, dtype=b.dtype)
135-
Ap = empty_like(x)
130+
x = dpt.zeros_like(b)
131+
Ap = dpt.empty_like(x)
136132

137133
all_host_tasks = []
138-
r = dpt.copy(b)
139-
p = dpt.copy(b)
134+
r = dpt.copy(b) # synchronous copy
135+
p = dpt.copy(b) # synchronous copy
136+
140137
rsold = sycl_gemm.norm_squared_blocking(exec_queue, r)
141138
if rsold < 1e-20:
142139
return (b, 0)
@@ -147,22 +144,21 @@ def cg_solve(A, b):
147144
e_x = dpctl.SyclEvent()
148145
for i in range(max_iters):
149146
# Ap = A @ p
150-
he_dot, e_dot = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
151-
all_host_tasks.append(he_dot)
147+
he_gemv, e_gemv = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
148+
all_host_tasks.append(he_gemv)
152149
# alpha = rsold / dot(p, Ap)
153150
alpha = rsold / sycl_gemm.dot_blocking(
154-
exec_queue, p, Ap, depends=[e_dot]
151+
exec_queue, p, Ap, depends=[e_p, e_gemv]
155152
)
156153
# x = x + alpha * p
157154
he1_x_update, e1_x_update = sycl_gemm.axpby_inplace(
158-
exec_queue, alpha, p, 1, x, depends=[e_p, e_x]
155+
exec_queue, alpha, p, 1, x, depends=[e_x]
159156
)
160157
all_host_tasks.append(he1_x_update)
161-
e_x = e1_x_update
162158

163159
# r = r - alpha * Ap
164160
he2_r_update, e2_r_update = sycl_gemm.axpby_inplace(
165-
exec_queue, -alpha, Ap, 1, r, depends=[e_p]
161+
exec_queue, -alpha, Ap, 1, r
166162
)
167163
all_host_tasks.append(he2_r_update)
168164

0 commit comments

Comments
 (0)