21
21
import dpctl .tensor as dpt
22
22
23
23
24
- def empty_like (A ):
25
- return dpt .empty (A .shape , A .dtype , device = A .device )
26
-
27
-
28
24
def chebyshev (A , b , x0 , nIters , lMax , lMin , depends = []):
29
25
"""Chebyshev iterative solver using SYCL routines"""
30
26
d = (lMax + lMin ) / 2
@@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
33
29
x = dpt .copy (x0 )
34
30
exec_queue = A .sycl_queue
35
31
assert exec_queue == x .sycl_queue
36
- Ax = empty_like (A [:, 0 ])
37
- r = empty_like (Ax )
38
- p = empty_like (Ax )
32
+ Ax = dpt . empty_like (A [:, 0 ])
33
+ r = dpt . empty_like (Ax )
34
+ p = dpt . empty_like (Ax )
39
35
40
36
e_x = dpctl .SyclEvent ()
41
37
# Ax = A @ x
@@ -131,12 +127,13 @@ def cg_solve(A, b):
131
127
converged is False if solver has not converged, or the iteration number
132
128
"""
133
129
exec_queue = A .sycl_queue
134
- x = dpt .zeros ( b . shape , dtype = b . dtype )
135
- Ap = empty_like (x )
130
+ x = dpt .zeros_like ( b )
131
+ Ap = dpt . empty_like (x )
136
132
137
133
all_host_tasks = []
138
- r = dpt .copy (b )
139
- p = dpt .copy (b )
134
+ r = dpt .copy (b ) # synchronous copy
135
+ p = dpt .copy (b ) # synchronous copy
136
+
140
137
rsold = sycl_gemm .norm_squared_blocking (exec_queue , r )
141
138
if rsold < 1e-20 :
142
139
return (b , 0 )
@@ -147,22 +144,21 @@ def cg_solve(A, b):
147
144
e_x = dpctl .SyclEvent ()
148
145
for i in range (max_iters ):
149
146
# Ap = A @ p
150
- he_dot , e_dot = sycl_gemm .gemv (exec_queue , A , p , Ap , depends = [e_p ])
151
- all_host_tasks .append (he_dot )
147
+ he_gemv , e_gemv = sycl_gemm .gemv (exec_queue , A , p , Ap , depends = [e_p ])
148
+ all_host_tasks .append (he_gemv )
152
149
# alpha = rsold / dot(p, Ap)
153
150
alpha = rsold / sycl_gemm .dot_blocking (
154
- exec_queue , p , Ap , depends = [e_dot ]
151
+ exec_queue , p , Ap , depends = [e_p , e_gemv ]
155
152
)
156
153
# x = x + alpha * p
157
154
he1_x_update , e1_x_update = sycl_gemm .axpby_inplace (
158
- exec_queue , alpha , p , 1 , x , depends = [e_p , e_x ]
155
+ exec_queue , alpha , p , 1 , x , depends = [e_x ]
159
156
)
160
157
all_host_tasks .append (he1_x_update )
161
- e_x = e1_x_update
162
158
163
159
# r = r - alpha * Ap
164
160
he2_r_update , e2_r_update = sycl_gemm .axpby_inplace (
165
- exec_queue , - alpha , Ap , 1 , r , depends = [ e_p ]
161
+ exec_queue , - alpha , Ap , 1 , r
166
162
)
167
163
all_host_tasks .append (he2_r_update )
168
164
0 commit comments