Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.linalg.inv() function #1665

Merged
merged 17 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions dpnp/backend/extensions/lapack/getrf_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ static sycl::event getrf_batch_impl(sycl::queue exec_q,
// Get the indices of the first zero diagonal elements of these matrices
auto error_info = be.exceptions();

auto error_matrices_ids_size = error_matrices_ids.size();
auto dev_info_size = static_cast<std::size_t>(py::len(dev_info));
if (error_matrices_ids_size != dev_info_size) {
throw py::value_error("The size of `dev_info` must be equal to" +
std::to_string(error_matrices_ids_size) +
", but currently it is " +
std::to_string(dev_info_size) + ".");
}

for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
// Assign the index of the first zero diagonal element in each
// error matrix to the corresponding index in 'dev_info'
Expand Down Expand Up @@ -190,6 +199,14 @@ std::pair<sycl::event, sycl::event>
", but a 2-dimensional array is expected.");
}

const int dev_info_size = py::len(dev_info);
if (dev_info_size != batch_size) {
throw py::value_error("The size of 'dev_info' (" +
std::to_string(dev_info_size) +
") does not match the expected batch size (" +
std::to_string(batch_size) + ").");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
throw py::value_error(
Expand Down
17 changes: 17 additions & 0 deletions dpnp/backend/extensions/lapack/getri_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,15 @@ static sycl::event getri_batch_impl(sycl::queue exec_q,
// Get the indices of the first zero diagonal elements of these matrices
auto error_info = be.exceptions();

auto error_matrices_ids_size = error_matrices_ids.size();
auto dev_info_size = static_cast<std::size_t>(py::len(dev_info));
if (error_matrices_ids_size != dev_info_size) {
throw py::value_error("The size of `dev_info` must be equal to" +
std::to_string(error_matrices_ids_size) +
", but currently it is " +
std::to_string(dev_info_size) + ".");
}

for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
// Assign the index of the first zero diagonal element in each
// error matrix to the corresponding index in 'dev_info'
Expand Down Expand Up @@ -188,6 +197,14 @@ std::pair<sycl::event, sycl::event>
", but a 2-dimensional array is expected.");
}

const int dev_info_size = py::len(dev_info);
if (dev_info_size != batch_size) {
throw py::value_error("The size of 'dev_info' (" +
std::to_string(dev_info_size) +
") does not match the expected batch size (" +
std::to_string(batch_size) + ").");
}

// check compatibility of execution queue and allocation queue
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) {
throw py::value_error(
Expand Down
9 changes: 0 additions & 9 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,15 +378,6 @@ template <typename _DataType, typename _ResultType>
void (*dpnp_inv_default_c)(void *, void *, shape_elem_type *, size_t) =
dpnp_inv_c<_DataType, _ResultType>;

template <typename _DataType, typename _ResultType>
DPCTLSyclEventRef (*dpnp_inv_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
shape_elem_type *,
size_t,
const DPCTLEventVectorRef) =
dpnp_inv_c<_DataType, _ResultType>;

template <typename _DataType1, typename _DataType2, typename _ResultType>
class dpnp_kron_c_kernel;

Expand Down
85 changes: 40 additions & 45 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,33 @@ def _calculate_determinant_sign(ipiv, diag, res_type, n):
return sign.astype(res_type)


def _check_lapack_dev_info(dev_info, error_msg=None):
"""
Check `dev_info` from OneMKL LAPACK routines, raising an error for failures.

Parameters
----------
dev_info : list of ints
Each element of the list indicates the status of OneMKL LAPACK routine calls.
A non-zero value signifies a failure.

error_message : str, optional
Custom error message for detected LAPACK errors.
Default: `Singular matrix`

Raises
------
dpnp.linalg.LinAlgError
On non-zero elements in dev_info, indicating LAPACK errors.

"""

if any(dev_info):
error_msg = error_msg or "Singular matrix"

raise dpnp.linalg.LinAlgError(error_msg)


def _real_type(dtype, device=None):
"""
Returns the real data type corresponding to a given dpnp data type.
Expand Down Expand Up @@ -442,35 +469,6 @@ def _lu_factor(a, res_type):
return (a_h, ipiv_h, dev_info_array)


def check_lapack_dev_info(dev_info, error_msg=None):
"""
Check `dev_info` from oneMKL LAPACK routines, raising an error for failures.

Parameters
----------
dev_info : list
Integers indicating the status of oneMKL LAPACK routine calls. A non-zero
value signifies a failure.

error_message : str, optional
Custom error message for detected LAPACK errors.
Default: `Singular matrix`

Raises
------
dpnp.linalg.LinAlgError
On non-zero elements in dev_info, indicating LAPACK errors.

"""

dev_info_array = dpnp.array(dev_info)

if (dev_info_array != 0).any():
error_msg = error_msg or "Singular matrix"

raise dpnp.linalg.LinAlgError(error_msg)


def dpnp_cholesky_batch(a, upper_lower, res_type):
"""
dpnp_cholesky_batch(a, upper_lower, res_type)
Expand Down Expand Up @@ -790,9 +788,6 @@ def dpnp_inv_batched(a, res_type):
a_usm_type = a.usm_type
n = a.shape[1]

if 0 in orig_shape:
return dpnp.empty_like(a, dtype=res_type)

# oneMKL LAPACK getri_batch overwrites `a`
a_h = dpnp.empty_like(a, order="C", dtype=res_type, usm_type=a_usm_type)
ipiv_h = dpnp.empty(
Expand All @@ -801,8 +796,7 @@ def dpnp_inv_batched(a, res_type):
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
dev_info_getrf_h = [0] * batch_size
dev_info_getri_h = [0] * batch_size
dev_info = [0] * batch_size

# use DPCTL tensor function to fill the matrix array
# with content from the input array `a`
Expand All @@ -815,39 +809,40 @@ def dpnp_inv_batched(a, res_type):

# Call the LAPACK extension function _getrf_batch
# to perform LU decomposition of a batch of general matrices
ht_lapack_ev, getrf_ev = li._getrf_batch(
ht_getrf_ev, getrf_ev = li._getrf_batch(
a_sycl_queue,
a_h.get_array(),
ipiv_h.get_array(),
dev_info_getrf_h,
dev_info,
n,
a_stride,
ipiv_stride,
batch_size,
[a_copy_ev],
)

_check_lapack_dev_info(dev_info)

# Call the LAPACK extension function _getri_batch
# to compute the inverse of a batch of matrices using the results
# from the LU decomposition performed by _getrf_batch
ht_lapack_ev_1, _ = li._getri_batch(
ht_getri_ev, _ = li._getri_batch(
a_sycl_queue,
a_h.get_array(),
ipiv_h.get_array(),
dev_info_getri_h,
dev_info,
n,
a_stride,
ipiv_stride,
batch_size,
[getrf_ev],
)

ht_lapack_ev_1.wait()
ht_lapack_ev.wait()
a_ht_copy_ev.wait()
_check_lapack_dev_info(dev_info)

check_lapack_dev_info(dev_info_getrf_h)
check_lapack_dev_info(dev_info_getri_h)
ht_getrf_ev.wait()
ht_getri_ev.wait()
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
a_ht_copy_ev.wait()

return a_h.reshape(orig_shape)

Expand All @@ -866,7 +861,7 @@ def dpnp_inv(a):

res_type = _common_type(a)
if a.size == 0:
return dpnp.empty_like(a, dtype=res_type, usm_type=a.usm_type)
return dpnp.empty_like(a, dtype=res_type)

if a.ndim >= 3:
return dpnp_inv_batched(a, res_type)
Expand All @@ -880,7 +875,7 @@ def dpnp_inv(a):

# oneMKL LAPACK gesv overwrites `a` and assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp arrays to comply with these requirements.
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
a_f = dpnp.empty_like(a, order=a_order, dtype=res_type, usm_type=a_usm_type)
a_f = dpnp.empty_like(a, order=a_order, dtype=res_type)

# use DPCTL tensor function to fill the coefficient matrix array
# with content from the input array `a`
Expand Down
14 changes: 11 additions & 3 deletions tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def check_x(self, a_shape, b_shape, xp, dtype):
a_copy = a.copy()
b_copy = b.copy()
result = xp.linalg.solve(a, b)
numpy.testing.assert_array_equal(a_copy, a)
numpy.testing.assert_array_equal(b_copy, b)
testing.assert_array_equal(a_copy, a)
testing.assert_array_equal(b_copy, b)
return result

def test_solve(self):
Expand Down Expand Up @@ -110,7 +110,7 @@ def check_x(self, a_shape, dtype):
result_gpu = cupy.linalg.inv(a_gpu)

assert_dtype_allclose(result_gpu, result_cpu)
numpy.testing.assert_array_equal(a_gpu_copy, a_gpu)
testing.assert_array_equal(a_gpu_copy, a_gpu)

def check_shape(self, a_shape):
a = cupy.random.rand(*a_shape)
Expand All @@ -130,6 +130,14 @@ def test_inv(self):
self.check_x((3, 0, 0))
self.check_x((2, 0, 3, 4, 4))

vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
def test_invalid_shape(self):
self.check_shape((2, 3))
self.check_shape((4, 1))
self.check_shape((4, 3, 2))
self.check_shape((2, 4, 3))
self.check_shape((2, 0))
self.check_shape((0, 2, 3))


class TestInvInvalid(unittest.TestCase):
@testing.for_dtypes("ifdFD")
Expand Down
Loading