Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support usm_ndarray batched input for dpnp.linalg #1880

Merged
merged 7 commits into from
Jun 17, 2024
6 changes: 3 additions & 3 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def tensorinv(a, ind=2):
old_shape = a.shape
inv_shape = old_shape[ind:] + old_shape[:ind]
prod = numpy.prod(old_shape[ind:])
a = a.reshape(prod, -1)
a = dpnp.reshape(a, (prod, -1))
a_inv = inv(a)

return a_inv.reshape(*inv_shape)
Expand Down Expand Up @@ -1428,7 +1428,7 @@ def tensorsolve(a, b, axes=None):
"prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
)

a = a.reshape(-1, prod)
b = b.ravel()
a = dpnp.reshape(a, (-1, prod))
b = dpnp.ravel(b)
res = solve(a, b)
return res.reshape(old_shape)
20 changes: 10 additions & 10 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _batched_eigh(a, UPLO, eigen_mode, w_type, v_type):
is_cpu_device = a.sycl_device.has_aspect_cpu
orig_shape = a.shape
# get 3d input array by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
a_usm_arr = dpnp.get_usm_ndarray(a)

# allocate a memory for dpnp array of eigenvalues
Expand Down Expand Up @@ -191,7 +191,7 @@ def _batched_inv(a, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, orig_shape[-2], orig_shape[-1])
a = dpnp.reshape(a, (-1, orig_shape[-2], orig_shape[-1]))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)
a_sycl_queue = a.sycl_queue
Expand Down Expand Up @@ -280,11 +280,11 @@ def _batched_solve(a, b, exec_q, res_usm_type, res_type):
if a.ndim > 3:
# get 3d input arrays by reshape
if a.ndim == b.ndim:
b = b.reshape(-1, b_shape[-2], b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-2], b_shape[-1]))
else:
b = b.reshape(-1, b_shape[-1])
b = dpnp.reshape(b, (-1, b_shape[-1]))

a = a.reshape(-1, a_shape[-2], a_shape[-1])
a = dpnp.reshape(a, (-1, a_shape[-2], a_shape[-1]))

a_usm_arr = dpnp.get_usm_ndarray(a)
b_usm_arr = dpnp.get_usm_ndarray(b)
Expand Down Expand Up @@ -386,7 +386,7 @@ def _batched_qr(a, mode="reduced"):
a_sycl_queue = a.sycl_queue

# get 3d input arrays by reshape
a = a.reshape(-1, m, n)
a = dpnp.reshape(a, (-1, m, n))

a = a.swapaxes(-2, -1)
a_usm_arr = dpnp.get_usm_ndarray(a)
Expand Down Expand Up @@ -537,7 +537,7 @@ def _batched_svd(

if a.ndim > 3:
# get 3d input arrays by reshape
a = a.reshape(prod(a.shape[:-2]), a.shape[-2], a.shape[-1])
a = dpnp.reshape(a, (prod(a.shape[:-2]), a.shape[-2], a.shape[-1]))
reshape = True

batch_size = a.shape[0]
Expand Down Expand Up @@ -830,7 +830,7 @@ def _lu_factor(a, res_type):
if a.ndim > 2:
orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -1743,7 +1743,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):

orig_shape = a.shape
# get 3d input arrays by reshape
a = a.reshape(-1, n, n)
a = dpnp.reshape(a, (-1, n, n))
batch_size = a.shape[0]
a_usm_arr = dpnp.get_usm_ndarray(a)

Expand Down Expand Up @@ -2171,7 +2171,7 @@ def dpnp_matrix_power(a, n):
# `result` will hold the final matrix power,
# while `acc` serves as an accumulator for the intermediate matrix powers.
result = None
acc = a.copy()
acc = dpnp.copy(a)
while n > 0:
n, bit = divmod(n, 2)
if bit:
Expand Down
66 changes: 66 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,72 @@ def vvsort(val, vec, size, xp):
vec[:, imax] = temp


@pytest.mark.parametrize(
"func, gen_kwargs, func_kwargs",
[
pytest.param("cholesky", {"hermitian": True}, {}),
pytest.param("cond", {}, {}),
pytest.param("det", {}, {}),
pytest.param("eig", {}, {}),
pytest.param("eigh", {"hermitian": True}, {}),
pytest.param("eigvals", {}, {}),
pytest.param("eigvalsh", {"hermitian": True}, {}),
pytest.param("inv", {}, {}),
pytest.param("matrix_power", {}, {"n": 4}),
pytest.param("matrix_rank", {}, {}),
pytest.param("norm", {}, {}),
pytest.param("pinv", {}, {}),
pytest.param("qr", {}, {}),
pytest.param("slogdet", {}, {}),
pytest.param("solve", {}, {}),
pytest.param("svd", {}, {}),
pytest.param("tensorinv", {}, {"ind": 1}),
pytest.param("tensorsolve", {}, {}),
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_usm_ndarray_input_batch(func, gen_kwargs, func_kwargs):
shape = (
(2, 2, 3, 3) if func not in ["tensorinv", "tensorsolve"] else (4, 2, 2)
)

if func in ["lstsq", "solve", "tensorsolve"]:
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
if func == "tensorsolve":
shape_b = (4,)
dpt_args = [
dpt.asarray(
generate_random_numpy_array(
shape, seed_value=81, **gen_kwargs
)
),
dpt.asarray(
generate_random_numpy_array(
shape_b, seed_value=81, **gen_kwargs
)
),
]
else:
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
dpt_args = [
dpt.asarray(
generate_random_numpy_array(
shape, seed_value=81, **gen_kwargs
)
)
for _ in range(2)
]
else:
dpt_args = [
dpt.asarray(generate_random_numpy_array(shape, **gen_kwargs))
]

result = getattr(inp.linalg, func)(*dpt_args, **func_kwargs)

if isinstance(result, tuple):
for res in result:
assert isinstance(res, inp.ndarray)
else:
assert isinstance(result, inp.ndarray)


class TestCholesky:
@pytest.mark.parametrize(
"array",
Expand Down
Loading