|
63 | 63 |
|
64 | 64 | timer = dpctl.SyclTimer(time_scale=1e3)
|
65 | 65 |
|
66 |
| -iters = [] |
67 |
| -for i in range(6): |
68 |
| - with timer(api_dev.sycl_queue): |
69 |
| - x, conv_in = solve.cg_solve(A, b) |
70 |
| - |
71 |
| - print(i, "(host_dt, device_dt)=", timer.dt) |
72 |
| - iters.append(conv_in) |
73 |
| - assert x.usm_type == A.usm_type |
74 |
| - assert x.usm_type == b.usm_type |
75 |
| - assert x.sycl_queue == A.sycl_queue |
76 |
| - assert x.sycl_queue == b.sycl_queue |
77 |
| - |
78 |
| -print("Converged in: ", iters) |
79 |
| - |
80 |
| -hev, ev = sycl_gemm.gemv(q, A, x, r) |
81 |
| -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
82 |
| -rs = sycl_gemm.norm_squared_blocking(q, delta) |
83 |
| -dpctl.SyclEvent.wait_for([hev, hev2]) |
84 |
| -print(f"Python solution residual norm squared: {rs}") |
| 66 | + |
| 67 | +def time_python_solver(num_iters=6): |
| 68 | + """ |
| 69 | + Time solver implemented in Python with use of asynchronous |
| 70 | + SYCL kernel submission. |
| 71 | + """ |
| 72 | + global x |
| 73 | + iters = [] |
| 74 | + for i in range(num_iters): |
| 75 | + with timer(api_dev.sycl_queue): |
| 76 | + x, conv_in = solve.cg_solve(A, b) |
| 77 | + |
| 78 | + print(i, "(host_dt, device_dt)=", timer.dt) |
| 79 | + iters.append(conv_in) |
| 80 | + assert x.usm_type == A.usm_type |
| 81 | + assert x.usm_type == b.usm_type |
| 82 | + assert x.sycl_queue == A.sycl_queue |
| 83 | + assert x.sycl_queue == b.sycl_queue |
| 84 | + |
| 85 | + return iters |
| 86 | + |
| 87 | + |
| 88 | +def time_cpp_solver(num_iters=6): |
| 89 | + """ |
| 90 | + Time solver implemented in C++ but callable from Python. |
| 91 | + C++ implementation uses the same algorithm and submits same |
| 92 | + kernels asynchronously, but bypasses Python binding overhead |
| 93 | + incurred when algorithm is driver from Python. |
| 94 | + """ |
| 95 | + global x_cpp |
| 96 | + x_cpp = dpt.empty_like(b) |
| 97 | + iters = [] |
| 98 | + for i in range(num_iters): |
| 99 | + with timer(api_dev.sycl_queue): |
| 100 | + conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) |
| 101 | + |
| 102 | + print(i, "(host_dt, device_dt)=", timer.dt) |
| 103 | + iters.append(conv_in) |
| 104 | + |
| 105 | + return iters |
| 106 | + |
| 107 | + |
| 108 | +def compute_residual(x): |
| 109 | + """ |
| 110 | + Computes quality of the solution, `norm_squared(A@x - b)`. |
| 111 | + """ |
| 112 | + assert isinstance(x, dpt.usm_ndarray) |
| 113 | + q = A.sycl_queue |
| 114 | + hev, ev = sycl_gemm.gemv(q, A, x, r) |
| 115 | + hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
| 116 | + rs = sycl_gemm.norm_squared_blocking(q, delta) |
| 117 | + dpctl.SyclEvent.wait_for([hev, hev2]) |
| 118 | + return rs |
| 119 | + |
| 120 | + |
| 121 | +print("Converged in: ", time_python_solver()) |
| 122 | +print(f"Python solution residual norm squared: {compute_residual(x)}") |
85 | 123 |
|
86 | 124 | assert q == api_dev.sycl_queue
|
87 | 125 | print("")
|
88 | 126 |
|
89 |
| -x_cpp = dpt.empty_like(b) |
90 |
| -iters = [] |
91 |
| -for i in range(6): |
92 |
| - with timer(api_dev.sycl_queue): |
93 |
| - conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp) |
94 |
| - |
95 |
| - print(i, "(host_dt, device_dt)=", timer.dt) |
96 |
| - iters.append(conv_in) |
97 |
| - |
98 |
| -print("Converged in: ", iters) |
99 |
| -hev, ev = sycl_gemm.gemv(q, A, x_cpp, r) |
100 |
| -hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev]) |
101 |
| -rs = sycl_gemm.norm_squared_blocking(q, delta) |
102 |
| -dpctl.SyclEvent.wait_for([hev, hev2]) |
103 |
| -print(f"cpp_cg_solve solution residual norm squared: {rs}") |
| 127 | +print("Converged in: ", time_cpp_solver()) |
| 128 | +print(f"cpp_cg_solve solution residual norm squared: {compute_residual(x_cpp)}") |
0 commit comments