Skip to content

Commit fc7041e

Browse files
Streamlined Chebyshev solver
Expose the cg_solve routine used by the standalone_cpp executable to Python as cpp_cg_solve, and invoke it from the Python script sycl_timing_solver.py:

```bash
$ python sycl_timing_solver.py 1000 11
Solving 1000 by 1000 diagonal linear system with rank 11 perturbation.
Name Intel(R) UHD Graphics [0x9bca]
Driver version 1.3.22992
Vendor Intel(R) Corporation
Profile FULL_PROFILE
Filter string level_zero:gpu:0
Using not in-order queue
0 (host_dt, device_dt)= (1157.4030127376318, 403.9605020000001)
1 (host_dt, device_dt)= (421.32044583559036, 403.45619400000004)
2 (host_dt, device_dt)= (420.66121101379395, 402.57058400000005)
3 (host_dt, device_dt)= (421.5433243662119, 402.9254920000001)
4 (host_dt, device_dt)= (421.9988752156496, 402.8818340000001)
5 (host_dt, device_dt)= (422.3589450120926, 402.63814600000006)
Converged in: [11, 11, 11, 11, 11, 11]
Python solution residual norm squared: 3.2839902926527995e-25
0 (host_dt, device_dt)= (412.9443597048521, 403.6290000000001)
1 (host_dt, device_dt)= (413.7023724615574, 403.8434720000001)
2 (host_dt, device_dt)= (413.4188834577799, 403.1985620000001)
3 (host_dt, device_dt)= (413.85203413665295, 402.70404800000006)
4 (host_dt, device_dt)= (416.2806496024132, 404.0513040000001)
5 (host_dt, device_dt)= (417.43320040404797, 404.74999800000006)
Converged in: [11, 11, 11, 11, 11, 11]
cpp_cg_solve solution residual norm squared: 3.218393087932091e-25
```
1 parent 3467eeb commit fc7041e

File tree

4 files changed

+154
-20
lines changed

4 files changed

+154
-20
lines changed

examples/pybind11/onemkl_gemv/solve.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,44 +38,44 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
3838
p = empty_like(Ax)
3939

4040
e_x = dpctl.SyclEvent()
41-
he_dot, e_dot = sycl_gemm.gemv(
42-
exec_queue, A, x, Ax, depends=depends
43-
) # Ax = A @ x
44-
he_sub, e_sub = sycl_gemm.sub(
45-
exec_queue, b, Ax, r, depends=[e_dot]
46-
) # r = b - Ax
41+
# Ax = A @ x
42+
_, e_dot = sycl_gemm.gemv(exec_queue, A, x, Ax, depends=depends)
43+
# r = b - Ax
44+
_, e_sub = sycl_gemm.sub(exec_queue, b, Ax, r, depends=[e_dot])
4745
r_ev = e_sub
4846
for i in range(nIters):
4947
z = r
5048
z_ev = r_ev
5149
if i == 0:
5250
p[:] = z
5351
alpha = 1 / d
54-
he_axpby, e_axpby = dpctl.SyclEvent(), dpctl.SyclEvent()
52+
_, e_axpby = dpctl.SyclEvent(), dpctl.SyclEvent()
5553
elif i == 1:
5654
beta = 0.5 * (c * alpha) ** 2
5755
alpha = 1 / (d - beta / alpha)
58-
he_axpby, e_axpby = sycl_gemm.axpby_inplace(
56+
# p = z + beta * p
57+
_, e_axpby = sycl_gemm.axpby_inplace(
5958
exec_queue, 1, z, beta, p, depends=[z_ev]
60-
) # p = z + beta * p
59+
)
6160
else:
6261
beta = (c / 2 * alpha) ** 2
6362
alpha = 1 / (d - beta / alpha)
64-
he_axpby, e_axpby = sycl_gemm.axpby_inplace(
63+
# p = z + beta * p
64+
_, e_axpby = sycl_gemm.axpby_inplace(
6565
exec_queue, 1, z, beta, p, depends=[z_ev]
66-
) # p = z + beta * p
67-
h_x, e_x = sycl_gemm.axpby_inplace(
66+
)
67+
# x = x + alpha * p
68+
_, e_x = sycl_gemm.axpby_inplace(
6869
exec_queue, alpha, p, 1, x, depends=[e_axpby, e_x]
69-
) # x = x + alpha * p
70-
he_dot, e_dot = sycl_gemm.gemv(
71-
exec_queue, A, x, Ax, depends=[e_x]
72-
) # Ax = A @ x
73-
he_sub, e_sub = sycl_gemm.sub(
74-
exec_queue, b, Ax, r, depends=[e_dot]
75-
) # r = b - Ax
70+
)
71+
# Ax = A @ x
72+
_, e_dot = sycl_gemm.gemv(exec_queue, A, x, Ax, depends=[e_x])
73+
# r = b - Ax
74+
_, e_sub = sycl_gemm.sub(exec_queue, b, Ax, r, depends=[e_dot])
75+
# residual = dot(r, r)
7676
residual = sycl_gemm.norm_squared_blocking(
7777
exec_queue, r, depends=[e_sub]
78-
) # residual = dot(r, r)
78+
)
7979
if residual <= 1e-29:
8080
print(f"chebyshev: converged in {i} iters")
8181
break

examples/pybind11/onemkl_gemv/sycl_gemm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from ._onemkl import (
1818
axpby_inplace,
19+
cpp_cg_solve,
1920
dot_blocking,
2021
gemv,
2122
norm_squared_blocking,
@@ -28,4 +29,5 @@
2829
"axpby_inplace",
2930
"norm_squared_blocking",
3031
"dot_blocking",
32+
"cpp_cg_solve",
3133
]

examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ py_gemv(sycl::queue q,
6262
throw std::runtime_error("Inconsistent shapes.");
6363
}
6464

65+
auto q_ctx = q.get_context();
66+
if (q_ctx != matrix.get_queue().get_context() ||
67+
q_ctx != vector.get_queue().get_context() ||
68+
q_ctx != result.get_queue().get_context())
69+
{
70+
throw std::runtime_error(
71+
"USM allocation is not bound to the context in execution queue.");
72+
}
73+
6574
int mat_flags = matrix.get_flags();
6675
int v_flags = vector.get_flags();
6776
int r_flags = result.get_flags();
@@ -176,6 +185,14 @@ py_sub(sycl::queue q,
176185
throw std::runtime_error("Vectors must have the same length");
177186
}
178187

188+
if (q.get_context() != in_v1.get_queue().get_context() ||
189+
q.get_context() != in_v2.get_queue().get_context() ||
190+
q.get_context() != out_r.get_queue().get_context())
191+
{
192+
throw std::runtime_error(
193+
"USM allocation is not bound to the context in execution queue");
194+
}
195+
179196
int in_v1_flags = in_v1.get_flags();
180197
int in_v2_flags = in_v2.get_flags();
181198
int out_r_flags = out_r.get_flags();
@@ -277,6 +294,13 @@ py_axpby_inplace(sycl::queue q,
277294
throw std::runtime_error("Vectors must have the same length");
278295
}
279296

297+
if (q.get_context() != x.get_queue().get_context() ||
298+
q.get_context() != y.get_queue().get_context())
299+
{
300+
throw std::runtime_error(
301+
"USM allocation is not bound to the context in execution queue");
302+
}
303+
280304
int x_flags = x.get_flags();
281305
int y_flags = y.get_flags();
282306

@@ -373,6 +397,11 @@ py::object py_norm_squared_blocking(sycl::queue q,
373397
throw std::runtime_error("Vector must be contiguous.");
374398
}
375399

400+
if (q.get_context() != r.get_queue().get_context()) {
401+
throw std::runtime_error(
402+
"USM allocation is not bound to the context in execution queue");
403+
}
404+
376405
int r_typenum = r.get_typenum();
377406
if ((r_typenum != UAR_DOUBLE) && (r_typenum != UAR_FLOAT) &&
378407
(r_typenum != UAR_CDOUBLE) && (r_typenum != UAR_CFLOAT))
@@ -437,6 +466,13 @@ py::object py_dot_blocking(sycl::queue q,
437466
throw std::runtime_error("Vectors must be contiguous.");
438467
}
439468

469+
if (q.get_context() != v1.get_queue().get_context() ||
470+
q.get_context() != v2.get_queue().get_context())
471+
{
472+
throw std::runtime_error(
473+
"USM allocation is not bound to the context in execution queue");
474+
}
475+
440476
int v1_typenum = v1.get_typenum();
441477
int v2_typenum = v2.get_typenum();
442478

@@ -500,6 +536,80 @@ py::object py_dot_blocking(sycl::queue q,
500536
return res;
501537
}
502538

539+
// Python-facing wrapper around the templated cg_solver::cg_solve.
//
// Validates the three arrays and then forwards their raw data pointers,
// the matrix order, the dependency events and the tolerance to
// cg_solver::cg_solve<T>, with T chosen from the element type of Amat.
//
// Parameters:
//   exec_q  - queue handed to cg_solve; every array must be bound to the
//             same SYCL context as this queue.
//   Amat    - two-dimensional, square, C-contiguous array.
//   bvec    - one-dimensional, C-contiguous array whose length equals the
//             order of Amat; passed to cg_solve as read-only data.
//   xvec    - one-dimensional, C-contiguous array of the same length;
//             passed to cg_solve through a mutable pointer (presumably
//             it receives the solution -- confirm against cg_solve).
//   rs_tol  - tolerance, cast to the element type before being forwarded
//             (the binding names it "rs_squared_tolerance").
//   depends - events forwarded unchanged to cg_solve.
//
// Returns the integer produced by cg_solve (the timing script reports it
// as the iteration count at convergence -- confirm against cg_solve).
//
// Throws py::value_error for wrong ranks, a non-square matrix,
// inconsistent lengths, non-C-contiguous inputs or mismatched element
// types; throws std::runtime_error for a context mismatch or an element
// type other than single/double precision.
int py_cg_solve(sycl::queue exec_q,
                dpctl::tensor::usm_ndarray Amat,
                dpctl::tensor::usm_ndarray bvec,
                dpctl::tensor::usm_ndarray xvec,
                double rs_tol,
                const std::vector<sycl::event> &depends = {})
{
    // Rank check: one matrix and two vectors.
    const bool ranks_ok = (Amat.get_ndim() == 2) && (bvec.get_ndim() == 1) &&
                          (xvec.get_ndim() == 1);
    if (!ranks_ok) {
        throw py::value_error("Expecting a matrix and two vectors");
    }

    py::ssize_t n_rows = Amat.get_shape(0);
    py::ssize_t n_cols = Amat.get_shape(1);
    if (n_rows != n_cols) {
        throw py::value_error("Matrix must be square.");
    }

    const bool lengths_ok =
        (n_rows == bvec.get_shape(0)) && (n_rows == xvec.get_shape(0));
    if (!lengths_ok) {
        throw py::value_error(
            "Dimensions of the matrix and vectors are not consistent.");
    }

    // The raw pointers below are handed to cg_solve without strides, so
    // every array has to be C-contiguous.
    if (!(Amat.get_flags() & USM_ARRAY_C_CONTIGUOUS) ||
        !(bvec.get_flags() & USM_ARRAY_C_CONTIGUOUS) ||
        !(xvec.get_flags() & USM_ARRAY_C_CONTIGUOUS))
    {
        throw py::value_error("All inputs must be C-contiguous");
    }

    const int A_typenum = Amat.get_typenum();
    const int b_typenum = bvec.get_typenum();
    const int x_typenum = xvec.get_typenum();
    const bool same_type =
        (A_typenum == b_typenum) && (A_typenum == x_typenum);
    if (!same_type) {
        throw py::value_error("All arrays must have the same type");
    }

    // All allocations must live in the context of the execution queue.
    const auto q_ctx = exec_q.get_context();
    if (q_ctx != Amat.get_queue().get_context() ||
        q_ctx != bvec.get_queue().get_context() ||
        q_ctx != xvec.get_queue().get_context())
    {
        throw std::runtime_error(
            "USM allocations are not bound to context in execution queue");
    }

    const char *A_bytes = Amat.get_data();
    const char *b_bytes = bvec.get_data();
    char *x_bytes = xvec.get_data();

    // Single dispatch body shared by both precisions; the tag argument
    // only carries the element type.
    const auto solve_as = [&](auto type_tag) -> int {
        using T = decltype(type_tag);
        return cg_solver::cg_solve<T>(
            exec_q, n_rows, reinterpret_cast<const T *>(A_bytes),
            reinterpret_cast<const T *>(b_bytes),
            reinterpret_cast<T *>(x_bytes), depends,
            static_cast<T>(rs_tol));
    };

    if (A_typenum == UAR_DOUBLE) {
        return solve_as(double{});
    }
    if (A_typenum == UAR_FLOAT) {
        return solve_as(float{});
    }

    throw std::runtime_error(
        "Unsupported data type. Use single or double precision.");
}
612+
503613
PYBIND11_MODULE(_onemkl, m)
504614
{
505615
// Import the dpctl extensions
@@ -518,4 +628,10 @@ PYBIND11_MODULE(_onemkl, m)
518628
py::arg("exec_queue"), py::arg("r"), py::arg("depends") = py::list());
519629
m.def("dot_blocking", &py_dot_blocking, "<v1, v2>", py::arg("exec_queue"),
520630
py::arg("v1"), py::arg("v2"), py::arg("depends") = py::list());
631+
632+
m.def("cpp_cg_solve", &py_cg_solve,
633+
"Dispatch to call C++ implementation of cg_solve",
634+
py::arg("exec_queue"), py::arg("Amat"), py::arg("bvec"),
635+
py::arg("xvec"), py::arg("rs_squared_tolerance") = py::float_(1e-20),
636+
py::arg("depends") = py::list());
521637
}

examples/pybind11/onemkl_gemv/sycl_timing_solver.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,19 @@
7474
rs = sycl_gemm.norm_squared_blocking(q, delta)
7575
dpctl.SyclEvent.wait_for([hev, hev2])
7676
print(f"Python solution residual norm squared: {rs}")
77+
78+
# Time the C++ solver exposed as sycl_gemm.cpp_cg_solve on the same system
# (A, b) that was solved above, writing the result into a fresh array.
x_cpp = dpt.empty_like(b)
iters = []
for i in range(6):
    # Only the solver call itself is timed.
    with timer(api_dev.sycl_queue):
        conv_in = sycl_gemm.cpp_cg_solve(q, A, b, x_cpp)

    print(i, "(host_dt, device_dt)=", timer.dt)
    iters.append(conv_in)

print("Converged in: ", iters)

# Check the answer: r = A @ x_cpp, delta = r - b, rs = dot(delta, delta).
hev, ev = sycl_gemm.gemv(q, A, x_cpp, r)
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
# sub() is asynchronous, so the reduction must be ordered after it
# explicitly; without depends=[ev2] an out-of-order queue is free to read
# delta before sub() has finished writing it.
rs = sycl_gemm.norm_squared_blocking(q, delta, depends=[ev2])
dpctl.SyclEvent.wait_for([hev, hev2])
print(f"cpp_cg_solve solution residual norm squared: {rs}")

0 commit comments

Comments
 (0)