Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dpnp.linalg.solve() hang on CPU #1778

Merged
merged 8 commits into from
Apr 11, 2024
27 changes: 21 additions & 6 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,7 @@ def dpnp_solve(a, b):
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

if a.ndim > 2:
is_cpu_device = exec_q.sycl_device.has_aspect_cpu
reshape = False
orig_shape_b = b_shape
if a.ndim > 3:
Expand All @@ -1850,22 +1851,27 @@ def dpnp_solve(a, b):
for i in range(batch_size):
# oneMKL LAPACK assumes fortran-like array as input, so
# allocate a memory with 'F' order for dpnp array of coefficient matrix
# and multiple dependent variables array
coeff_vecs[i] = dpnp.empty_like(
a[i], order="F", dtype=res_type, usm_type=res_usm_type
)
val_vecs[i] = dpnp.empty_like(
b[i], order="F", dtype=res_type, usm_type=res_usm_type
)

# use DPCTL tensor function to fill the coefficient matrix array
# and the array of multiple dependent variables with content
# from the input arrays
# with content from the input array
a_ht_copy_ev[i], a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr[i],
dst=coeff_vecs[i].get_array(),
sycl_queue=a.sycl_queue,
)

# oneMKL LAPACK assumes fortran-like array as input, so
# allocate a memory with 'F' order for dpnp array of multiple
# dependent variables array
val_vecs[i] = dpnp.empty_like(
b[i], order="F", dtype=res_type, usm_type=res_usm_type
)

# use DPCTL tensor function to fill the array of multiple dependent
# variables with content from the input arrays
b_ht_copy_ev[i], b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=b_usm_arr[i],
dst=val_vecs[i].get_array(),
Expand All @@ -1882,6 +1888,15 @@ def dpnp_solve(a, b):
depends=[a_copy_ev, b_copy_ev],
)

# TODO: Remove this w/a when MKLD-17201 is solved.
# Waiting for a host task executing an OneMKL LAPACK gesv call
# on CPU causes deadlock due to serialization of all host tasks
# in the queue.
# We need to wait for each host tasks before calling _gesv to avoid deadlock.
if is_cpu_device:
ht_lapack_ev[i].wait()
b_ht_copy_ev[i].wait()

for i in range(batch_size):
ht_lapack_ev[i].wait()
b_ht_copy_ev[i].wait()
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
3 changes: 0 additions & 3 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,9 +899,6 @@ def test_eigenvalue(func, shape, usm_type):
)
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
x = dp.array(matrix, usm_type=usm_type_matrix)
if x.ndim > 2 and x.device.sycl_device.is_cpu:
pytest.skip("SAT-6842: reported hanging in public CI")

y = dp.array(vector, usm_type=usm_type_vector)
z = dp.linalg.solve(x, y)

Expand Down
1 change: 0 additions & 1 deletion tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def check_x(self, a_shape, b_shape, xp, dtype):
testing.assert_array_equal(b_copy, b)
return result

@pytest.mark.skipif(is_cpu_device(), reason="SAT-6842")
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
def test_solve(self):
self.check_x((4, 4), (4,))
self.check_x((5, 5), (5, 2))
Expand Down
Loading