Skip to content

Refine gemv example #833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions examples/pybind11/onemkl_gemv/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@
import dpctl.tensor as dpt


def empty_like(A):
return dpt.empty(A.shape, A.dtype, device=A.device)


def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
"""Chebyshev iterative solver using SYCL routines"""
d = (lMax + lMin) / 2
Expand All @@ -33,9 +29,9 @@ def chebyshev(A, b, x0, nIters, lMax, lMin, depends=[]):
x = dpt.copy(x0)
exec_queue = A.sycl_queue
assert exec_queue == x.sycl_queue
Ax = empty_like(A[:, 0])
r = empty_like(Ax)
p = empty_like(Ax)
Ax = dpt.empty_like(A[:, 0])
r = dpt.empty_like(Ax)
p = dpt.empty_like(Ax)

e_x = dpctl.SyclEvent()
# Ax = A @ x
Expand Down Expand Up @@ -131,12 +127,13 @@ def cg_solve(A, b):
converged is False if solver has not converged, or the iteration number
"""
exec_queue = A.sycl_queue
x = dpt.zeros(b.shape, dtype=b.dtype)
Ap = empty_like(x)
x = dpt.zeros_like(b)
Ap = dpt.empty_like(x)

all_host_tasks = []
r = dpt.copy(b)
p = dpt.copy(b)
r = dpt.copy(b) # synchronous copy
p = dpt.copy(b) # synchronous copy

rsold = sycl_gemm.norm_squared_blocking(exec_queue, r)
if rsold < 1e-20:
return (b, 0)
Expand All @@ -147,22 +144,21 @@ def cg_solve(A, b):
e_x = dpctl.SyclEvent()
for i in range(max_iters):
# Ap = A @ p
he_dot, e_dot = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
all_host_tasks.append(he_dot)
he_gemv, e_gemv = sycl_gemm.gemv(exec_queue, A, p, Ap, depends=[e_p])
all_host_tasks.append(he_gemv)
# alpha = rsold / dot(p, Ap)
alpha = rsold / sycl_gemm.dot_blocking(
exec_queue, p, Ap, depends=[e_dot]
exec_queue, p, Ap, depends=[e_p, e_gemv]
)
# x = x + alpha * p
he1_x_update, e1_x_update = sycl_gemm.axpby_inplace(
exec_queue, alpha, p, 1, x, depends=[e_p, e_x]
exec_queue, alpha, p, 1, x, depends=[e_x]
)
all_host_tasks.append(he1_x_update)
e_x = e1_x_update

# r = r - alpha * Ap
he2_r_update, e2_r_update = sycl_gemm.axpby_inplace(
exec_queue, -alpha, Ap, 1, r, depends=[e_p]
exec_queue, -alpha, Ap, 1, r
)
all_host_tasks.append(he2_r_update)

Expand Down
4 changes: 4 additions & 0 deletions examples/pybind11/onemkl_gemv/sycl_gemm/_onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::float_(res_v);
}
else if (v1_typenum == UAR_FLOAT) {
Expand All @@ -507,6 +508,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v(0);
q.copy<T>(res_usm, &res_v, 1, {dot_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::float_(res_v);
}
else if (v1_typenum == UAR_CDOUBLE) {
Expand All @@ -517,6 +519,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::cast(res_v);
}
else if (v1_typenum == UAR_CFLOAT) {
Expand All @@ -527,6 +530,7 @@ py::object py_dot_blocking(sycl::queue q,
reinterpret_cast<const T *>(v2_typeless_ptr), 1, res_usm, depends);
T res_v{};
q.copy<T>(res_usm, &res_v, 1, {dotc_ev}).wait_and_throw();
sycl::free(res_usm, q);
res = py::cast(res_v);
}
else {
Expand Down
9 changes: 4 additions & 5 deletions examples/pybind11/onemkl_gemv/sycl_gemm/cg_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ int cg_solve(sycl::queue exec_q,
}

int converged_at = max_iters;
sycl::event prev_dep = copy_to_p_ev;
sycl::event e_p = copy_to_p_ev;
sycl::event e_x = fill_ev;

for (std::int64_t i = 0; i < max_iters; ++i) {
sycl::event gemv_ev = oneapi::mkl::blas::row_major::gemv(
exec_q, oneapi::mkl::transpose::N, n, n, T(1), Amat, n, p, 1, T(0),
Ap, 1, {prev_dep});
Ap, 1, {e_p});

sycl::event pAp_dot_ev = oneapi::mkl::blas::row_major::dot(
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {prev_dep, gemv_ev});
exec_q, n, p, 1, Ap, 1, pAp_dot_dev, {e_p, gemv_ev});

T pAp_dot_host{};
exec_q.copy<T>(pAp_dot_dev, &pAp_dot_host, 1, {pAp_dot_ev})
Expand All @@ -212,8 +212,7 @@ int cg_solve(sycl::queue exec_q,
T beta = rs_new / rs_old;

// p = r + beta * p
prev_dep =
detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
e_p = detail::axpby_inplace(exec_q, n, T(1), r, beta, p, {r_update_ev});
e_x = x_update_ev;

rs_old = rs_new;
Expand Down
15 changes: 13 additions & 2 deletions examples/pybind11/onemkl_gemv/sycl_timing_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@
A = dpt.asarray(Anp, "d", device=api_dev)
b = dpt.asarray(bnp, "d", device=api_dev)

assert A.sycl_queue == b.sycl_queue

# allocate buffers for computation of residual
r = dpt.empty_like(b)
delta = dpt.empty_like(b)

timer = dpctl.SyclTimer(time_scale=1e3)

iters = []
Expand All @@ -64,17 +70,22 @@

print(i, "(host_dt, device_dt)=", timer.dt)
iters.append(conv_in)
assert x.usm_type == A.usm_type
assert x.usm_type == b.usm_type
assert x.sycl_queue == A.sycl_queue
assert x.sycl_queue == b.sycl_queue

print("Converged in: ", iters)

r = dpt.empty_like(b)
hev, ev = sycl_gemm.gemv(q, A, x, r)
delta = dpt.empty_like(b)
hev2, ev2 = sycl_gemm.sub(q, r, b, delta, [ev])
rs = sycl_gemm.norm_squared_blocking(q, delta)
dpctl.SyclEvent.wait_for([hev, hev2])
print(f"Python solution residual norm squared: {rs}")

assert q == api_dev.sycl_queue
print("")

x_cpp = dpt.empty_like(b)
iters = []
for i in range(6):
Expand Down