Skip to content

Commit

Permalink
Support usm_ndarray batched input for dpnp.linalg (#1880)
Browse files Browse the repository at this point in the history
* Add usm_ndarray input support for linalg

* Add test_usm_ndarray_input_batch to test_linalg.py

* Add usm_ndarray input support for dpnp_iface_linearalgebra

* Add test_usm_ndarray_linearalgebra_batch to test_linalg.py

* Apply comments

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Jun 17, 2024
1 parent a813fae commit 96a2e41
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 19 deletions.
12 changes: 6 additions & 6 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,13 +892,13 @@ def outer(a, b, out=None):
dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False)
if dpnp.isscalar(a):
x1 = a
x2 = b.ravel()[None, :]
x2 = dpnp.ravel(b)[None, :]
elif dpnp.isscalar(b):
x1 = a.ravel()[:, None]
x1 = dpnp.ravel(a)[:, None]
x2 = b
else:
x1 = a.ravel()
x2 = b.ravel()
x1 = dpnp.ravel(a)
x2 = dpnp.ravel(b)

return dpnp.multiply.outer(x1, x2, out=out)

Expand Down Expand Up @@ -1056,8 +1056,8 @@ def tensordot(a, b, axes=2):
newshape_b = (n1, n2)
oldb = [b_shape[axis] for axis in notin]

at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
res = dpnp.matmul(at, bt)

return res.reshape(olda + oldb)
Expand Down
6 changes: 3 additions & 3 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def tensorinv(a, ind=2):
old_shape = a.shape
inv_shape = old_shape[ind:] + old_shape[:ind]
prod = numpy.prod(old_shape[ind:])
a = a.reshape(prod, -1)
a = dpnp.reshape(a, (prod, -1))
a_inv = inv(a)

return a_inv.reshape(*inv_shape)
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def tensorsolve(a, b, axes=None):
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
)

a = a.reshape(-1, prod)
b = b.ravel()
a = dpnp.reshape(a, (-1, prod))
b = dpnp.ravel(b)
res = solve(a, b)
return res.reshape(old_shape)
20 changes: 10 additions & 10 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
is_cpu_device = a.sycl_device.has_aspect_cpu
orig_shape = a.shape
# get 3d input array by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
a_usm_arr = dpnp.get_usm_ndarray(a)

# allocate a memory for dpnp array of eigenvalues
Expand Down Expand Up @@ -191,7 +191,7 @@ def _batched_inv(a, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)
a_sycl_queue = a.sycl_queue
Expand Down Expand Up @@ -280,11 +280,11 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type):
if a.ndim > 3:
# get 3d input arrays by reshape
if a.ndim == b.ndim:
b = b.reshape(-1, b_shape[-2], b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1]))
else:
b = b.reshape(-1, b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-1]))

a = a.reshape(-1, a_shape[-2], a_shape[-1])
a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1]))

a_usm_arr = dpnp.get_usm_ndarray(a)
b_usm_arr = dpnp.get_usm_ndarray(b)
Expand Down Expand Up @@ -386,7 +386,7 @@ def _batched_qr(a, mode="reduced"):
a_sycl_queue = a.sycl_queue

# get 3d input arrays by reshape
a = a.reshape(-1, m, n)
a = dpnp.reshape(a, (-1, m, n))

a = a.swapaxes(-2, -1)
a_usm_arr = dpnp.get_usm_ndarray(a)
Expand Down Expand Up @@ -537,7 +537,7 @@ def _batched_svd(

if a.ndim > 3:
# get 3d input arrays by reshape
a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
a = dpnp.reshape(a, (prod(a.shape[:-2]), a.shape[-2], a.shape[-1]))
reshape = True

batch_size = a.shape[0]
Expand Down Expand Up @@ -830,7 +830,7 @@ def _lu_factor(a, res_type):
if a.ndim > 2:
orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -1743,7 +1743,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -2171,7 +2171,7 @@ def dpnp_matrix_power(a, n):
# `result` will hold the final matrix power,
# while `acc` serves as an accumulator for the intermediate matrix powers.
result = None
acc = a.copy()
acc = dpnp.copy(a)
while n > 0:
n, bit = divmod(n, 2)
if bit:
Expand Down
81 changes: 81 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,87 @@ def vvsort(val, vec, size, xp):
vec[:, imax] = temp


# check linear algebra functions from dpnp.linalg
# with multidimensional usm_ndarray as input
@pytest.mark.parametrize(
    "func, gen_kwargs, func_kwargs",
    [
        pytest.param("cholesky", {"hermitian": True}, {}),
        pytest.param("cond", {}, {}),
        pytest.param("det", {}, {}),
        pytest.param("eig", {}, {}),
        pytest.param("eigh", {"hermitian": True}, {}),
        pytest.param("eigvals", {}, {}),
        pytest.param("eigvalsh", {"hermitian": True}, {}),
        pytest.param("inv", {}, {}),
        pytest.param("matrix_power", {}, {"n": 4}),
        pytest.param("matrix_rank", {}, {}),
        pytest.param("norm", {}, {}),
        pytest.param("pinv", {}, {}),
        pytest.param("qr", {}, {}),
        pytest.param("slogdet", {}, {}),
        pytest.param("solve", {}, {}),
        pytest.param("svd", {}, {}),
        pytest.param("tensorinv", {}, {"ind": 1}),
        pytest.param("tensorsolve", {}, {}),
    ],
)
def test_usm_ndarray_linalg_batch(func, gen_kwargs, func_kwargs):
    """Check that each dpnp.linalg function accepts batched usm_ndarray
    input and returns dpnp.ndarray result(s).

    ``gen_kwargs`` is forwarded to the random-array generator (e.g. to
    produce Hermitian input), ``func_kwargs`` to the function under test.
    """
    # tensorinv/tensorsolve need prod of leading dims to match trailing
    # dims, hence the distinct (4, 2, 2) shape for those two functions.
    shape = (
        (2, 2, 3, 3) if func not in ["tensorinv", "tensorsolve"] else (4, 2, 2)
    )

    if func == "tensorsolve":
        # tensorsolve takes a second, lower-dimensional right-hand side.
        shape_b = (4,)
        dpt_args = [
            dpt.asarray(
                generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
            ),
            dpt.asarray(
                generate_random_numpy_array(
                    shape_b, seed_value=81, **gen_kwargs
                )
            ),
        ]
    elif func in ["lstsq", "solve"]:
        # These functions take two same-shaped operands.
        dpt_args = [
            dpt.asarray(
                generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
            )
            for _ in range(2)
        ]
    else:
        # Pass seed_value here too so this branch is deterministic,
        # consistent with the branches above.
        dpt_args = [
            dpt.asarray(
                generate_random_numpy_array(shape, seed_value=81, **gen_kwargs)
            )
        ]

    result = getattr(inp.linalg, func)(*dpt_args, **func_kwargs)

    # Functions such as qr/svd/slogdet return tuples; every element
    # must have been converted to dpnp.ndarray.
    if isinstance(result, tuple):
        for res in result:
            assert isinstance(res, inp.ndarray)
    else:
        assert isinstance(result, inp.ndarray)


# check linear algebra functions from dpnp
# with multidimensional usm_ndarray as input
@pytest.mark.parametrize(
    "func", ["dot", "inner", "kron", "matmul", "outer", "tensordot", "vdot"]
)
def test_usm_ndarray_linearalgebra_batch(func):
    """Each dpnp linear-algebra routine must accept multidimensional
    usm_ndarray operands and produce a dpnp.ndarray result."""
    operand_shape = (2, 2, 2, 2)

    # Two identically-shaped, deterministic usm_ndarray operands.
    operands = [
        dpt.asarray(
            generate_random_numpy_array(operand_shape, seed_value=81)
        )
        for _ in range(2)
    ]

    func_under_test = getattr(inp, func)
    result = func_under_test(*operands)

    assert isinstance(result, inp.ndarray)


class TestCholesky:
@pytest.mark.parametrize(
"array",
Expand Down

0 comments on commit 96a2e41

Please sign in to comment.