Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.linalg.matrix_rank() implementation #1717

Merged
merged 9 commits into from
Feb 19, 2024
9 changes: 0 additions & 9 deletions dpnp/backend/kernels/dpnp_krnl_linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,15 +579,6 @@ template <typename _DataType>
void (*dpnp_matrix_rank_default_c)(void *, void *, shape_elem_type *, size_t) =
dpnp_matrix_rank_c<_DataType>;

template <typename _DataType>
DPCTLSyclEventRef (*dpnp_matrix_rank_ext_c)(DPCTLSyclQueueRef,
void *,
void *,
shape_elem_type *,
size_t,
const DPCTLEventVectorRef) =
dpnp_matrix_rank_c<_DataType>;

template <typename _InputDT, typename _ComputeDT>
DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref,
void *array1_in,
Expand Down
4 changes: 4 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ def matrix_rank(A, tol=None, hermitian=False):
rank : (...) dpnp.ndarray
Rank of A.

See Also
--------
:obj:`dpnp.linalg.svd` : Singular Value Decomposition.

Examples
--------
>>> import dpnp as np
Expand Down
16 changes: 4 additions & 12 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,20 +1014,12 @@ def dpnp_matrix_rank(A, tol=None, hermitian=False):
S = dpnp_svd(A, compute_uv=False, hermitian=hermitian)

if tol is None:
tol = (
S.max(axis=-1, keepdims=True)
* max(A.shape[-2:])
* dpnp.finfo(S.dtype).eps
)
rtol = max(A.shape[-2:]) * dpnp.finfo(S.dtype).eps
tol = S.max(axis=-1, keepdims=True) * rtol
else:
if dpnp.is_supported_array_type(tol):
# Check that `a` and `tol` are allocated on the same device
# and have the same queue. Otherwise, `ValueError`` will be raised.
get_usm_allocations([A, tol])
else:
# Allocate dpnp.ndarray if tol is a scalar
tol = dpnp.array(tol, usm_type=A.usm_type, sycl_queue=A.sycl_queue)
tol = tol[..., None]
# Add a new axis to match Numpy's output
tol = tol[..., None]
return dpnp.count_nonzero(S > tol, axis=-1)


Expand Down
11 changes: 8 additions & 3 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def test_inv_errors(self):


class TestMatrixRank:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"data",
[
Expand All @@ -586,7 +586,7 @@ def test_matrix_rank(self, data, dtype):
dp_rank = inp.linalg.matrix_rank(a_dp)
assert np_rank == dp_rank

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
"data",
[
Expand Down Expand Up @@ -666,7 +666,12 @@ def test_matrix_rank_errors(self):
tol_queue = dpctl.SyclQueue()
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
tol_dp_q = inp.array([0.5], dtype="float32", sycl_queue=tol_queue)
assert_raises(ValueError, inp.linalg.matrix_rank, a_dp_q, tol_dp_q)
assert_raises(
dpctl.utils._compute_follows_data.ExecutionPlacementError,
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
inp.linalg.matrix_rank,
a_dp_q,
tol_dp_q,
)


@pytest.mark.usefixtures("allow_fall_back_on_numpy")
Expand Down