From 96a2e4111fe025d66d4a6b795cd6ff5f45de8afd Mon Sep 17 00:00:00 2001
From: vlad-perevezentsev
Date: Mon, 17 Jun 2024 14:34:33 +0200
Subject: [PATCH] Support usm_ndarray batched input for dpnp.linalg (#1880)

* Add usm_ndarray input support for linalg

* Add test_usm_ndarray_input_batch to test_linalg.py

* Add usm_ndarray input support for dpnp_iface_linearalgebra

* Add test_usm_ndarray_linearalgebra_batch to test_linalg.py

* Apply comments

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
---
 dpnp/dpnp_iface_linearalgebra.py | 12 ++---
 dpnp/linalg/dpnp_iface_linalg.py |  6 +--
 dpnp/linalg/dpnp_utils_linalg.py | 20 ++++----
 tests/test_linalg.py             | 81 ++++++++++++++++++++++++++++++++
 4 files changed, 100 insertions(+), 19 deletions(-)

diff --git a/dpnp/dpnp_iface_linearalgebra.py b/dpnp/dpnp_iface_linearalgebra.py
index f674c96040a..aef9203746f 100644
--- a/dpnp/dpnp_iface_linearalgebra.py
+++ b/dpnp/dpnp_iface_linearalgebra.py
@@ -892,13 +892,13 @@ def outer(a, b, out=None):
     dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
     if dpnp.isscalar(a):
         x1 = a
-        x2 = b.ravel()[None, :]
+        x2 = dpnp.ravel(b)[None, :]
     elif dpnp.isscalar(b):
-        x1 = a.ravel()[:, None]
+        x1 = dpnp.ravel(a)[:, None]
         x2 = b
     else:
-        x1 = a.ravel()
-        x2 = b.ravel()
+        x1 = dpnp.ravel(a)
+        x2 = dpnp.ravel(b)
 
     return dpnp.multiply.outer(x1, x2, out=out)
 
@@ -1056,8 +1056,8 @@ def tensordot(a, b, axes=2):
         newshape_b = (n1, n2)
         oldb = [b_shape[axis] for axis in notin]
 
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
+    at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
+    bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
     res = dpnp.matmul(at, bt)
 
     return res.reshape(olda + oldb)
diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py
index 72d79ad329d..5342daa1758 100644
--- a/dpnp/linalg/dpnp_iface_linalg.py
+++ b/dpnp/linalg/dpnp_iface_linalg.py
@@ -1354,7 +1354,7 @@ def tensorinv(a, ind=2):
     old_shape = a.shape
     inv_shape = old_shape[ind:] + old_shape[:ind]
     prod = numpy.prod(old_shape[ind:])
-    a = a.reshape(prod, -1)
+    a = dpnp.reshape(a, (prod, -1))
     a_inv = inv(a)
 
     return a_inv.reshape(*inv_shape)
@@ -1428,7 +1428,7 @@ def tensorsolve(a, b, axes=None):
             "prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
         )
 
-    a = a.reshape(-1, prod)
-    b = b.ravel()
+    a = dpnp.reshape(a, (-1, prod))
+    b = dpnp.ravel(b)
     res = solve(a, b)
     return res.reshape(old_shape)
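Note: the substitutions above (method calls like `a.ravel()` and `a.reshape(...)` replaced with module-level `dpnp.ravel(a)` and `dpnp.reshape(a, ...)`) matter because the module-level functions accept a raw `dpctl.tensor.usm_ndarray`, while the method calls assume the input is already a `dpnp.ndarray`. A minimal sketch of the difference, assuming a default SYCL device is available (variable names are illustrative only):

    import dpctl.tensor as dpt
    import dpnp

    a = dpt.ones((2, 3))     # raw usm_ndarray, not a dpnp.ndarray
    flat = dpnp.ravel(a)     # works: dpnp functions accept usm_ndarray input
    # a.ravel()              # fails: at the time of this patch, usm_ndarray
    #                        # does not expose a ravel() method

Once the first module-level call has returned a `dpnp.ndarray`, chained method calls such as `.reshape(newshape_a)` in `tensordot` above are safe again.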
diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py
index 22aa396c7fe..10e297937ee 100644
--- a/dpnp/linalg/dpnp_utils_linalg.py
+++ b/dpnp/linalg/dpnp_utils_linalg.py
@@ -99,7 +99,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
     is_cpu_device = a.sycl_device.has_aspect_cpu
     orig_shape = a.shape
     # get 3d input array by reshape
-    a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
+    a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
     a_usm_arr = dpnp.get_usm_ndarray(a)
 
     # allocate a memory for dpnp array of eigenvalues
@@ -191,7 +191,7 @@ def _batched_inv(a, res_type):
 
     orig_shape = a.shape
     # get 3d input arrays by reshape
-    a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
+    a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
    batch_size = a.shape[0]
     a_usm_arr = dpnp.get_usm_ndarray(a)
     a_sycl_queue = a.sycl_queue
@@ -280,11 +280,11 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type):
     if a.ndim > 3:
         # get 3d input arrays by reshape
         if a.ndim == b.ndim:
-            b = b.reshape(-1, b_shape[-2], b_shape[-1])
+            b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1]))
         else:
-            b = b.reshape(-1, b_shape[-1])
+            b = dpnp.reshape(b, (-1, b_shape[-1]))
 
-        a = a.reshape(-1, a_shape[-2], a_shape[-1])
+        a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1]))
 
     a_usm_arr = dpnp.get_usm_ndarray(a)
     b_usm_arr = dpnp.get_usm_ndarray(b)
@@ -386,7 +386,7 @@ def _batched_qr(a, mode="reduced"):
     a_sycl_queue = a.sycl_queue
 
     # get 3d input arrays by reshape
-    a = a.reshape(-1, m, n)
+    a = dpnp.reshape(a, (-1, m, n))
     a = a.swapaxes(-2, -1)
 
     a_usm_arr = dpnp.get_usm_ndarray(a)
@@ -537,7 +537,7 @@ def _batched_svd(
 
     if a.ndim > 3:
         # get 3d input arrays by reshape
-        a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
+        a = dpnp.reshape(a, (prod(a.shape[:-2]), a.shape[-2], a.shape[-1]))
         reshape = True
 
     batch_size = a.shape[0]
@@ -830,7 +830,7 @@ def _lu_factor(a, res_type):
     if a.ndim > 2:
         orig_shape = a.shape
         # get 3d input arrays by reshape
-        a = a.reshape(-1, n, n)
+        a = dpnp.reshape(a, (-1, n, n))
         batch_size = a.shape[0]
         a_usm_arr = dpnp.get_usm_ndarray(a)
 
@@ -1743,7 +1743,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):
 
     orig_shape = a.shape
     # get 3d input arrays by reshape
-    a = a.reshape(-1, n, n)
+    a = dpnp.reshape(a, (-1, n, n))
     batch_size = a.shape[0]
     a_usm_arr = dpnp.get_usm_ndarray(a)
 
@@ -2171,7 +2171,7 @@ def dpnp_matrix_power(a, n):
     # `result` will hold the final matrix power,
     # while `acc` serves as an accumulator for the intermediate matrix powers.
     result = None
-    acc = a.copy()
+    acc = dpnp.copy(a)
     while n > 0:
         n, bit = divmod(n, 2)
         if bit:
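Note: every `_batched_*` helper above follows the same pattern: flatten the leading batch dimensions into one, run the per-matrix routine on the resulting 3-D stack, and reshape the result back at the end. A rough illustration of that round trip (plain dpnp calls, not the internal code path; shapes are arbitrary):

    import dpnp

    a = dpnp.ones((2, 2, 3, 3))   # 2x2 batch of 3x3 matrices
    orig_shape = a.shape

    # collapse batch dimensions: (2, 2, 3, 3) -> (4, 3, 3)
    a3d = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
    assert a3d.shape == (4, 3, 3)

    # ... per-matrix LAPACK work happens on the 3-D stack ...

    # restore the original batch shape
    restored = dpnp.reshape(a3d, orig_shape)
    assert restored.shape == (2, 2, 3, 3)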
diff --git a/tests/test_linalg.py b/tests/test_linalg.py
index 48a4891034c..b718a2cec87 100644
--- a/tests/test_linalg.py
+++ b/tests/test_linalg.py
@@ -57,6 +57,87 @@ def vvsort(val, vec, size, xp):
         vec[:, imax] = temp
 
 
+# check linear algebra functions from dpnp.linalg
+# with multidimensional usm_ndarray as input
+@pytest.mark.parametrize(
+    "func, gen_kwargs, func_kwargs",
+    [
+        pytest.param("cholesky", {"hermitian": True}, {}),
+        pytest.param("cond", {}, {}),
+        pytest.param("det", {}, {}),
+        pytest.param("eig", {}, {}),
+        pytest.param("eigh", {"hermitian": True}, {}),
+        pytest.param("eigvals", {}, {}),
+        pytest.param("eigvalsh", {"hermitian": True}, {}),
+        pytest.param("inv", {}, {}),
+        pytest.param("matrix_power", {}, {"n": 4}),
+        pytest.param("matrix_rank", {}, {}),
+        pytest.param("norm", {}, {}),
+        pytest.param("pinv", {}, {}),
+        pytest.param("qr", {}, {}),
+        pytest.param("slogdet", {}, {}),
+        pytest.param("solve", {}, {}),
+        pytest.param("svd", {}, {}),
+        pytest.param("tensorinv", {}, {"ind": 1}),
+        pytest.param("tensorsolve", {}, {}),
+    ],
+)
+def test_usm_ndarray_linalg_batch(func, gen_kwargs, func_kwargs):
+    shape = (
+        (2, 2, 3, 3) if func not in ["tensorinv", "tensorsolve"] else (4, 2, 2)
+    )
+
+    if func == "tensorsolve":
+        shape_b = (4,)
+        dpt_args = [
+            dpt.asarray(
+                generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
+            ),
+            dpt.asarray(
+                generate_random_numpy_array(
+                    shape_b, seed_value=81, **gen_kwargs
+                )
+            ),
+        ]
+    elif func in ["lstsq", "solve"]:
+        dpt_args = [
+            dpt.asarray(
+                generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
+            )
+            for _ in range(2)
+        ]
+    else:
+        dpt_args = [
+            dpt.asarray(generate_random_numpy_array(shape, **gen_kwargs))
+        ]
+
+    result = getattr(inp.linalg, func)(*dpt_args, **func_kwargs)
+
+    if isinstance(result, tuple):
+        for res in result:
+            assert isinstance(res, inp.ndarray)
+    else:
+        assert isinstance(result, inp.ndarray)
+
+
+# check linear algebra functions from dpnp
+# with multidimensional usm_ndarray as input
+@pytest.mark.parametrize(
+    "func", ["dot", "inner", "kron", "matmul", "outer", "tensordot", "vdot"]
+)
+def test_usm_ndarray_linearalgebra_batch(func):
+    shape = (2, 2, 2, 2)
+
+    dpt_args = [
+        dpt.asarray(generate_random_numpy_array(shape, seed_value=81))
+        for _ in range(2)
+    ]
+
+    result = getattr(inp, func)(*dpt_args)
+
+    assert isinstance(result, inp.ndarray)
+
+
 class TestCholesky:
     @pytest.mark.parametrize(
         "array",
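Note: the behavior these tests pin down can be reproduced standalone. A small usage sketch, assuming a SYCL device is available (the choice of `inv` is arbitrary; any function from the parametrization behaves the same way):

    import numpy
    import dpctl.tensor as dpt
    import dpnp

    # a 2x2 batch of 3x3 invertible matrices, passed in as a raw usm_ndarray
    batch = dpt.asarray(numpy.tile(numpy.eye(3), (2, 2, 1, 1)))

    result = dpnp.linalg.inv(batch)          # batched usm_ndarray input is accepted
    assert isinstance(result, dpnp.ndarray)  # result comes back as dpnp.ndarray
    assert result.shape == (2, 2, 3, 3)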