Skip to content

Implementing repeat function #875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
3 changes: 0 additions & 3 deletions ci/Numba-array-api-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_data_type_functions.py::test_finfo[complex64]
array_api_tests/test_manipulation_functions.py::test_squeeze
array_api_tests/test_has_names.py::test_has_names[utility-diff]
array_api_tests/test_has_names.py::test_has_names[manipulation-repeat]
array_api_tests/test_has_names.py::test_has_names[manipulation-tile]
array_api_tests/test_has_names.py::test_has_names[manipulation-unstack]
array_api_tests/test_has_names.py::test_has_names[statistical-cumulative_sum]
Expand All @@ -79,7 +78,6 @@ array_api_tests/test_has_names.py::test_has_names[indexing-take_along_axis]
array_api_tests/test_has_names.py::test_has_names[searching-count_nonzero]
array_api_tests/test_has_names.py::test_has_names[searching-searchsorted]
array_api_tests/test_signatures.py::test_func_signature[diff]
array_api_tests/test_signatures.py::test_func_signature[repeat]
array_api_tests/test_signatures.py::test_func_signature[tile]
array_api_tests/test_signatures.py::test_func_signature[unstack]
array_api_tests/test_signatures.py::test_func_signature[take_along_axis]
Expand Down Expand Up @@ -107,7 +105,6 @@ array_api_tests/test_statistical_functions.py::test_cumulative_sum
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_1[None]
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[1]
array_api_tests/test_array_object.py::test_getitem_arrays_and_ints_2[None]
array_api_tests/test_manipulation_functions.py::test_repeat
array_api_tests/test_searching_functions.py::test_count_nonzero
array_api_tests/test_searching_functions.py::test_searchsorted
array_api_tests/test_manipulation_functions.py::test_tile
Expand Down
2 changes: 2 additions & 0 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@
permute_dims,
prod,
real,
repeat,
reshape,
round,
squeeze,
Expand Down Expand Up @@ -335,6 +336,7 @@
"where",
"zeros",
"zeros_like",
"repeat",
]


Expand Down
45 changes: 44 additions & 1 deletion sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np

from ._coo import as_coo
from ._coo import as_coo, expand_dims
from ._sparse_array import SparseArray
from ._utils import (
_zero_of_dtype,
Expand Down Expand Up @@ -3104,3 +3104,46 @@ def vecdot(x1, x2, /, *, axis=-1):
x1 = np.conjugate(x1)

return np.sum(x1 * x2, axis=axis, dtype=np.result_type(x1, x2))


def repeat(a, repeats, axis=None):
"""
Repeat each element of an array after themselves

Parameters
----------
a : SparseArray
Input sparse arrays
repeats : int
The number of repetitions for each element.
(Uneven repeats are not yet Implemented.)
axis : int, optional
The axis along which to repeat values. Returns a flattened sparse array if not specified.

Returns
-------
out : SparseArray
A sparse array which has the same shape as a, except along the given axis.
"""
if not isinstance(a, SparseArray):
raise TypeError("`a` must be a SparseArray.")

if not isinstance(repeats, int):
raise ValueError("`repeats` must be an integer, uneven repeats are not yet Implemented.")
axes = list(range(a.ndim))
new_shape = list(a.shape)
axis_is_none = False
if axis is None:
a = a.reshape(-1)
axis = 0
axis_is_none = True
if axis < 0:
axis = a.ndim + axis
axes[a.ndim - 1], axes[axis] = axes[axis], axes[a.ndim - 1]
new_shape[axis] *= repeats
a = expand_dims(a, axis=axis + 1)
shape_to_broadcast = a.shape[: axis + 1] + (a.shape[axis + 1] * repeats,) + a.shape[axis + 2 :]
a = broadcast_to(a, shape_to_broadcast)
if not axis_is_none:
return a.reshape(new_shape)
return a.reshape(new_shape).flatten()
30 changes: 30 additions & 0 deletions sparse/numba_backend/tests/test_coo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,3 +1926,33 @@ def test_xH_x():
assert_eq(Ysp.conj().T @ Y, Y.conj().T @ Y)
assert_eq(Ysp.conj().T @ Ysp, Y.conj().T @ Y)
assert_eq(Y.conj().T @ Ysp.conj().T, Y.conj().T @ Y.conj().T)


def test_repeat_invalid_input():
a = np.eye(3)
with pytest.raises(TypeError, match="`a` must be a SparseArray"):
sparse.repeat(a, repeats=2)
with pytest.raises(ValueError, match="`repeats` must be an integer"):
sparse.repeat(COO.from_numpy(a), repeats=[2, 2, 2])


@pytest.mark.parametrize("ndim", range(1, 5))
@pytest.mark.parametrize("repeats", [1, 2, 3])
def test_repeat(ndim, repeats):
rng = np.random.default_rng()
shape = tuple(rng.integers(1, 4) for _ in range(ndim))
a = rng.integers(1, 10, size=shape)
sparse_a = COO.from_numpy(a)
for axis in [*range(-ndim, ndim), None]:
expected = np.repeat(a, repeats=repeats, axis=axis)
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=axis)
actual = result_sparse.todense()
assert actual.shape == expected.shape, f"Shape mismatch on axis {axis}: {actual.shape} vs {expected.shape}"
np.testing.assert_array_equal(actual, expected)

expected = np.repeat(a, repeats=repeats, axis=None)
result_sparse = sparse.repeat(sparse_a, repeats=repeats, axis=None)
actual = result_sparse.todense()
print(f"Expected: {expected}, Actual: {actual}")
assert actual.shape == expected.shape
np.testing.assert_array_equal(actual, expected)
1 change: 1 addition & 0 deletions sparse/numba_backend/tests/test_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def test_namespace():
"real",
"reciprocal",
"remainder",
"repeat",
"reshape",
"result_type",
"roll",
Expand Down
Loading