From d34b8ce47fc9220350f8adf6fd5bc8358df5c5b9 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Mon, 19 Feb 2024 21:19:57 +0100 Subject: [PATCH] Update dpnp.linalg.matrix_rank() implementation (#1717) * Update dpnp.linalg.matrix_rank impl * Add cupy tests for dpnp.linalg.matrix_rank * Add dpnp tests for dpnp.linalg.matrix_rank * Remove old impl of dpnp_matrix_rank * Address remarks * Address remarks --- dpnp/backend/include/dpnp_iface_fptr.hpp | 42 +++--- dpnp/backend/kernels/dpnp_krnl_linalg.cpp | 18 --- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/linalg/dpnp_algo_linalg.pyx | 46 ------ dpnp/linalg/dpnp_iface_linalg.py | 49 ++++--- dpnp/linalg/dpnp_utils_linalg.py | 24 ++++ tests/test_linalg.py | 135 ++++++++++++++---- tests/test_sycl_queue.py | 28 +++- tests/test_usm_type.py | 21 +++ .../cupy/linalg_tests/test_norms.py | 30 ++++ 10 files changed, 253 insertions(+), 142 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index e9a3458f84a..24b01f5ff11 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -178,28 +178,26 @@ enum class DPNPFuncName : size_t DPNP_FN_KRON, /**< Used in numpy.kron() impl */ DPNP_FN_KRON_EXT, /**< Used in numpy.kron() impl, requires extra parameters */ - DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */ - DPNP_FN_LOG, /**< Used in numpy.log() impl */ - DPNP_FN_LOG10, /**< Used in numpy.log10() impl */ - DPNP_FN_LOG2, /**< Used in numpy.log2() impl */ - DPNP_FN_LOG1P, /**< Used in numpy.log1p() impl */ - DPNP_FN_MATMUL, /**< Used in numpy.matmul() impl */ - DPNP_FN_MATRIX_RANK, /**< Used in numpy.linalg.matrix_rank() impl */ - DPNP_FN_MATRIX_RANK_EXT, /**< Used in numpy.linalg.matrix_rank() impl, - requires extra parameters */ - DPNP_FN_MAX, /**< Used in numpy.max() impl */ - DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */ - DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra - parameters */ - DPNP_FN_MEAN, /**< Used in numpy.mean() impl */ - DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */ - DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra - parameters */ - DPNP_FN_MIN, /**< Used in numpy.min() impl */ - DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */ - DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra - parameters */ - DPNP_FN_MODF, /**< Used in numpy.modf() impl */ + DPNP_FN_LEFT_SHIFT, /**< Used in numpy.left_shift() impl */ + DPNP_FN_LOG, /**< Used in numpy.log() impl */ + DPNP_FN_LOG10, /**< Used in numpy.log10() impl */ + DPNP_FN_LOG2, /**< Used in numpy.log2() impl */ + DPNP_FN_LOG1P, /**< Used in numpy.log1p() impl */ + DPNP_FN_MATMUL, /**< Used in numpy.matmul() impl */ + DPNP_FN_MATRIX_RANK, /**< Used in numpy.linalg.matrix_rank() impl */ + DPNP_FN_MAX, /**< Used in numpy.max() impl */ + DPNP_FN_MAXIMUM, /**< Used in numpy.fmax() impl */ + DPNP_FN_MAXIMUM_EXT, /**< Used in numpy.fmax() impl , requires extra + parameters */ + DPNP_FN_MEAN, /**< Used in numpy.mean() impl */ + DPNP_FN_MEDIAN, /**< Used in numpy.median() impl */ + DPNP_FN_MEDIAN_EXT, /**< Used in numpy.median() impl, requires extra + parameters */ + DPNP_FN_MIN, /**< Used in numpy.min() impl */ + DPNP_FN_MINIMUM, /**< Used in numpy.fmin() impl */ + DPNP_FN_MINIMUM_EXT, /**< Used in numpy.fmax() impl, requires extra + parameters */ + DPNP_FN_MODF, /**< Used in numpy.modf() impl */ DPNP_FN_MODF_EXT, /**< Used in numpy.modf() impl, requires extra parameters */ DPNP_FN_MULTIPLY, /**< Used in numpy.multiply() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp index d74c593115e..8f70ddd01e3 100644 --- a/dpnp/backend/kernels/dpnp_krnl_linalg.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_linalg.cpp @@ -579,15 +579,6 @@ template void (*dpnp_matrix_rank_default_c)(void *, void *, shape_elem_type *, size_t) = dpnp_matrix_rank_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_matrix_rank_ext_c)(DPCTLSyclQueueRef, - void *, - void *, - shape_elem_type *, - size_t, - const DPCTLEventVectorRef) = - dpnp_matrix_rank_c<_DataType>; - template DPCTLSyclEventRef dpnp_qr_c(DPCTLSyclQueueRef q_ref, void *array1_in, @@ -969,15 +960,6 @@ void func_map_init_linalg_func(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_matrix_rank_default_c}; - fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_matrix_rank_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_matrix_rank_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_FLT][eft_FLT] = { - eft_FLT, (void *)dpnp_matrix_rank_ext_c}; - fmap[DPNPFuncName::DPNP_FN_MATRIX_RANK_EXT][eft_DBL][eft_DBL] = { - eft_DBL, (void *)dpnp_matrix_rank_ext_c}; - fmap[DPNPFuncName::DPNP_FN_QR][eft_INT][eft_INT] = { eft_DBL, (void *)dpnp_qr_default_c}; fmap[DPNPFuncName::DPNP_FN_QR][eft_LNG][eft_LNG] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 71382d38f26..3ad23b08fbe 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -78,8 +78,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_FULL_LIKE DPNP_FN_KRON DPNP_FN_KRON_EXT - DPNP_FN_MATRIX_RANK - DPNP_FN_MATRIX_RANK_EXT DPNP_FN_MAXIMUM DPNP_FN_MAXIMUM_EXT DPNP_FN_MEDIAN diff --git a/dpnp/linalg/dpnp_algo_linalg.pyx b/dpnp/linalg/dpnp_algo_linalg.pyx index 67cd5d93034..83597ad7595 100644 --- a/dpnp/linalg/dpnp_algo_linalg.pyx +++ b/dpnp/linalg/dpnp_algo_linalg.pyx @@ -48,24 +48,14 @@ __all__ = [ "dpnp_cond", "dpnp_eig", "dpnp_eigvals", - "dpnp_matrix_rank", "dpnp_norm", ] # C function pointer to the C library template functions -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef, - void *, void * ,shape_elem_type * , - size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef, - void * , void * , size_t * , - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_1out_with_size_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef, void *, void * , size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_1in_3out_shape_t)(c_dpctl.DPCTLSyclQueueRef, - void *, void * , void * , void * , - size_t , size_t, const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*custom_linalg_2in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef, void *, void * , void * , size_t, const c_dpctl.DPCTLEventVectorRef) @@ -183,42 +173,6 @@ cpdef utils.dpnp_descriptor dpnp_eigvals(utils.dpnp_descriptor input): return res_val -cpdef utils.dpnp_descriptor dpnp_matrix_rank(utils.dpnp_descriptor input): - cdef shape_type_c input_shape = input.shape - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_MATRIX_RANK_EXT, param1_type, param1_type) - - input_obj = input.get_array() - - # create result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor((1,), - kernel_data.return_type, - None, - device=input_obj.sycl_device, - usm_type=input_obj.usm_type, - sycl_queue=input_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef custom_linalg_1in_1out_func_ptr_t func = kernel_data.ptr - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - input.get_data(), - result.get_data(), - input_shape.data(), - input.ndim, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef object dpnp_norm(object input, ord=None, axis=None): cdef long size_input = input.size cdef shape_type_c shape_input = input.shape diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 14853cd991e..33abe38c4e2 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -51,6 +51,7 @@ dpnp_det, dpnp_eigh, dpnp_inv, + dpnp_matrix_rank, dpnp_pinv, dpnp_qr, dpnp_slogdet, @@ -397,47 +398,57 @@ def matrix_power(input, count): return call_origin(numpy.linalg.matrix_power, input, count) -def matrix_rank(input, tol=None, hermitian=False): +def matrix_rank(A, tol=None, hermitian=False): """ - Return matrix rank of array. + Return matrix rank of array using SVD method. Rank of the array is the number of singular values of the array that are greater than `tol`. Parameters ---------- - M : {(M,), (..., M, N)} array_like + A : {(M,), (..., M, N)} {dpnp.ndarray, usm_ndarray} Input vector or stack of matrices. - tol : (...) array_like, float, optional + tol : (...) {float, dpnp.ndarray, usm_ndarray}, optional Threshold below which SVD values are considered zero. If `tol` is None, and ``S`` is an array with singular values for `M`, and ``eps`` is the epsilon value for datatype of ``S``, then `tol` is set to ``S.max() * max(M.shape) * eps``. hermitian : bool, optional - If True, `M` is assumed to be Hermitian (symmetric if real-valued), + If True, `A` is assumed to be Hermitian (symmetric if real-valued), enabling a more efficient method for finding singular values. Defaults to False. Returns ------- - rank : (...) array_like - Rank of M. + rank : (...) dpnp.ndarray + Rank of A. - """ + See Also + -------- + :obj:`dpnp.linalg.svd` : Singular Value Decomposition. - x1_desc = dpnp.get_dpnp_descriptor(input, copy_when_nondefault_queue=False) - if x1_desc: - if tol is not None: - pass - elif hermitian: - pass - else: - result_obj = dpnp_matrix_rank(x1_desc).get_pyobj() - result = dpnp.convert_single_elem_array_to_scalar(result_obj) + Examples + -------- + >>> import dpnp as np + >>> from dpnp.linalg import matrix_rank + >>> matrix_rank(np.eye(4)) # Full rank matrix + array(4) + >>> I=np.eye(4); I[-1,-1] = 0. # rank deficient matrix + >>> matrix_rank(I) + array(3) + >>> matrix_rank(np.ones((4,))) # 1 dimension - rank 1 unless all 0 + array(1) + >>> matrix_rank(np.zeros((4,))) + array(0) - return result + """ + + dpnp.check_supported_arrays_type(A) + if tol is not None: + dpnp.check_supported_arrays_type(tol, scalar_type=True) - return call_origin(numpy.linalg.matrix_rank, input, tol, hermitian) + return dpnp_matrix_rank(A, tol=tol, hermitian=hermitian) def multi_dot(arrays, out=None): diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index b92dcae0f47..b8c366f5413 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -39,6 +39,7 @@ "dpnp_det", "dpnp_eigh", "dpnp_inv", + "dpnp_matrix_rank", "dpnp_pinv", "dpnp_qr", "dpnp_slogdet", @@ -999,6 +1000,29 @@ def dpnp_inv(a): return b_f +def dpnp_matrix_rank(A, tol=None, hermitian=False): + """ + dpnp_matrix_rank(A, tol=None, hermitian=False) + + Return matrix rank of array using SVD method. + + """ + + if A.ndim < 2: + return (A != 0).any().astype(int) + + S = dpnp_svd(A, compute_uv=False, hermitian=hermitian) + + if tol is None: + rtol = max(A.shape[-2:]) * dpnp.finfo(S.dtype).eps + tol = S.max(axis=-1, keepdims=True) * rtol + elif not dpnp.isscalar(tol): + # Add a new axis to match Numpy's output + tol = tol[..., None] + + return dpnp.count_nonzero(S > tol, axis=-1) + + def dpnp_pinv(a, rcond=1e-15, hermitian=False): """ dpnp_pinv(a, rcond=1e-15, hermitian=False): diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 5cf226762af..65cfa949db0 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1,6 +1,7 @@ import dpctl import numpy import pytest +from dpctl.utils import ExecutionPlacementError from numpy.testing import ( assert_allclose, assert_almost_equal, @@ -565,37 +566,113 @@ def test_inv_errors(self): assert_raises(inp.linalg.LinAlgError, inp.linalg.inv, a_dp) -@pytest.mark.parametrize( - "type", get_all_dtypes(no_bool=True, no_complex=True, no_none=True) -) -@pytest.mark.parametrize( - "array", - [ - [0, 0], - [0, 1], - [1, 2], - [[0, 0], [0, 0]], - [[1, 2], [1, 2]], - [[1, 2], [3, 4]], - ], - ids=[ - "[0, 0]", - "[0, 1]", - "[1, 2]", - "[[0, 0], [0, 0]]", - "[[1, 2], [1, 2]]", - "[[1, 2], [3, 4]]", - ], -) -@pytest.mark.parametrize("tol", [None], ids=["None"]) -def test_matrix_rank(type, tol, array): - a = numpy.array(array, dtype=type) - ia = inp.array(a) +class TestMatrixRank: + @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize( + "data", + [ + numpy.eye(4), + numpy.diag([1, 1, 1, 0]), + numpy.zeros((4, 4)), + numpy.array([1, 0, 0, 0]), + numpy.zeros((4,)), + numpy.array(1), + ], + ) + def test_matrix_rank(self, data, dtype): + a = data.astype(dtype) + a_dp = inp.array(a) - result = inp.linalg.matrix_rank(ia, tol=tol) - expected = numpy.linalg.matrix_rank(a, tol=tol) + np_rank = numpy.linalg.matrix_rank(a) + dp_rank = inp.linalg.matrix_rank(a_dp) + assert np_rank == dp_rank - assert_allclose(expected, result) + @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize( + "data", + [ + numpy.eye(4), + numpy.ones((4, 4)), + numpy.zeros((4, 4)), + numpy.diag([1, 1, 1, 0]), + ], + ) + def test_matrix_rank_hermitian(self, data, dtype): + a = data.astype(dtype) + a_dp = inp.array(a) + + np_rank = numpy.linalg.matrix_rank(a, hermitian=True) + dp_rank = inp.linalg.matrix_rank(a_dp, hermitian=True) + assert np_rank == dp_rank + + @pytest.mark.parametrize( + "high_tol, low_tol", + [ + (0.99e-6, 1.01e-6), + (numpy.array(0.99e-6), numpy.array(1.01e-6)), + (numpy.array([0.99e-6]), numpy.array([1.01e-6])), + ], + ids=[ + "float", + "0-D array", + "1-D array", + ], + ) + def test_matrix_rank_tolerance(self, high_tol, low_tol): + a = numpy.eye(4) + a[-1, -1] = 1e-6 + a_dp = inp.array(a) + + if isinstance(high_tol, numpy.ndarray): + dp_high_tol = inp.array( + high_tol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue + ) + dp_low_tol = inp.array( + low_tol, usm_type=a_dp.usm_type, sycl_queue=a_dp.sycl_queue + ) + else: + dp_high_tol = high_tol + dp_low_tol = low_tol + + np_rank_high_tol = numpy.linalg.matrix_rank( + a, hermitian=True, tol=high_tol + ) + dp_rank_high_tol = inp.linalg.matrix_rank( + a_dp, hermitian=True, tol=dp_high_tol + ) + assert np_rank_high_tol == dp_rank_high_tol + + np_rank_low_tol = numpy.linalg.matrix_rank( + a, hermitian=True, tol=low_tol + ) + dp_rank_low_tol = inp.linalg.matrix_rank( + a_dp, hermitian=True, tol=dp_low_tol + ) + assert np_rank_low_tol == dp_rank_low_tol + + def test_matrix_rank_errors(self): + a_dp = inp.array([[1, 2], [3, 4]], dtype="float32") + + # unsupported type `a` + a_np = inp.asnumpy(a_dp) + assert_raises(TypeError, inp.linalg.matrix_rank, a_np) + + # unsupported type `tol` + tol = numpy.array(0.5, dtype="float32") + assert_raises(TypeError, inp.linalg.matrix_rank, a_dp, tol) + assert_raises(TypeError, inp.linalg.matrix_rank, a_dp, [0.5]) + + # diffetent queue + a_queue = dpctl.SyclQueue() + tol_queue = dpctl.SyclQueue() + a_dp_q = inp.array(a_dp, sycl_queue=a_queue) + tol_dp_q = inp.array([0.5], dtype="float32", sycl_queue=tol_queue) + assert_raises( + ExecutionPlacementError, + inp.linalg.matrix_rank, + a_dp_q, + tol_dp_q, + ) @pytest.mark.usefixtures("allow_fall_back_on_numpy") diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 888891d80f6..dd6a9f5c639 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1192,20 +1192,36 @@ def test_inv(shape, is_empty, device): assert_sycl_queue_equal(result_queue, expected_queue) +@pytest.mark.parametrize( + "data, tol", + [ + (numpy.array([1, 2]), None), + (numpy.array([[1, 2], [3, 4]]), None), + (numpy.array([[1, 2], [3, 4]]), 1e-06), + ], + ids=[ + "1-D array", + "2-D array no tol", + "2_d array with tol", + ], +) @pytest.mark.parametrize( "device", valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_matrix_rank(device): - data = [[0, 0], [0, 0]] - numpy_data = numpy.array(data) - dpnp_data = dpnp.array(data, device=device) +def test_matrix_rank(data, tol, device): + dp_data = dpnp.array(data, device=device) - result = dpnp.linalg.matrix_rank(dpnp_data) - expected = numpy.linalg.matrix_rank(numpy_data) + result = dpnp.linalg.matrix_rank(dp_data, tol=tol) + expected = numpy.linalg.matrix_rank(data, tol=tol) assert_array_equal(expected, result) + expected_queue = dp_data.get_array().sycl_queue + result_queue = result.get_array().sycl_queue + + assert_sycl_queue_equal(result_queue, expected_queue) + @pytest.mark.parametrize( "shape", diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 43f526ebcb4..fc9993642eb 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -829,6 +829,27 @@ def test_svd(usm_type, shape, full_matrices_param, compute_uv_param): assert x.usm_type == s.usm_type +@pytest.mark.parametrize( + "data, tol", + [ + (numpy.array([1, 2]), None), + (numpy.array([[1, 2], [3, 4]]), None), + (numpy.array([[1, 2], [3, 4]]), 1e-06), + ], + ids=[ + "1-D array", + "2-D array no tol", + "2_d array with tol", + ], +) +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_matrix_rank(data, tol, usm_type): + a = dp.array(data, usm_type=usm_type) + + dp_res = dp.linalg.matrix_rank(a, tol=tol) + assert a.usm_type == dp_res.usm_type + + @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) @pytest.mark.parametrize( "shape, hermitian", diff --git a/tests/third_party/cupy/linalg_tests/test_norms.py b/tests/third_party/cupy/linalg_tests/test_norms.py index 2ed49d16057..c2ae3fe0dbb 100644 --- a/tests/third_party/cupy/linalg_tests/test_norms.py +++ b/tests/third_party/cupy/linalg_tests/test_norms.py @@ -8,6 +8,36 @@ from tests.third_party.cupy import testing +@testing.parameterize( + *testing.product( + { + "array": [ + [[1, 2], [3, 4]], + [[1, 2], [1, 2]], + [[0, 0], [0, 0]], + [1, 2], + [0, 1], + [0, 0], + ], + "tol": [None, 1], + } + ) +) +class TestMatrixRank(unittest.TestCase): + @testing.for_all_dtypes(no_float16=True) + @testing.numpy_cupy_array_equal(type_check=True) + def test_matrix_rank(self, xp, dtype): + a = xp.array(self.array, dtype=dtype) + y = xp.linalg.matrix_rank(a, tol=self.tol) + if xp is cupy: + assert isinstance(y, cupy.ndarray) + assert y.shape == () + else: + # Note numpy returns numpy scalar or python int + y = xp.array(y) + return y + + # TODO: Remove the use of fixture for all tests in this file # when dpnp.prod() will support complex dtypes on Gen9 @pytest.mark.usefixtures("allow_fall_back_on_numpy")