Skip to content

Add stats.is_constant #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
80171a3
WIP is-constant
flying-sheep Feb 18, 2025
319cb87
docs
flying-sheep Feb 18, 2025
042b50b
support sparse
flying-sheep Feb 18, 2025
05fbe8c
fix h5py and zarr axis=None
flying-sheep Feb 18, 2025
f49a797
undo h5 and zarr for now
flying-sheep Feb 18, 2025
dedd75a
Merge branch 'main' into pa/is-constant
flying-sheep Feb 20, 2025
846e7d6
fix mypy
flying-sheep Feb 20, 2025
3133c60
Merge branch 'main' into pa/is-constant
flying-sheep Feb 20, 2025
7fefddb
doctests
flying-sheep Feb 20, 2025
edde305
fix tests
flying-sheep Feb 20, 2025
2141694
Merge branch 'main' into pa/is-constant
flying-sheep Feb 21, 2025
e06cd9d
doc
flying-sheep Feb 21, 2025
2e58b4e
Merge branch 'main' into pa/is-constant
flying-sheep Feb 25, 2025
7a44781
generalize benchmark
flying-sheep Feb 25, 2025
d6a1fa6
some renames
flying-sheep Feb 25, 2025
7d6ac57
fewer variables
flying-sheep Feb 25, 2025
19a4187
extract validation
flying-sheep Feb 25, 2025
4619e57
Merge branch 'main' into pa/is-constant
flying-sheep Feb 25, 2025
2e767f4
less type:ignore
flying-sheep Feb 25, 2025
79150e8
Merge branch 'main' into pa/is-constant
flying-sheep Feb 25, 2025
64c2baa
actually test dask with sparse as well
flying-sheep Feb 25, 2025
b53742a
fix typing?
flying-sheep Feb 25, 2025
b4639e2
use np.bool
flying-sheep Feb 27, 2025
c8f6fc2
fix is_constant
flying-sheep Feb 27, 2025
f498da9
test that all blocks are individually constant
flying-sheep Feb 27, 2025
dc7b2d6
add viz
flying-sheep Feb 27, 2025
86664ba
asymmetric
flying-sheep Feb 27, 2025
521fb63
rename test
flying-sheep Feb 27, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
)
# Try overriding type paths
qualname_overrides = autodoc_type_aliases = {
"np.bool": ("py:data", "numpy.bool"),
"np.dtype": "numpy.dtype",
"np.number": "numpy.number",
"np.integer": "numpy.integer",
Expand Down Expand Up @@ -94,6 +95,7 @@
("py:class", "ArrayLike"),
("py:class", "DTypeLike"),
("py:class", "NDArray"),
("py:class", "np.bool"),
("py:class", "_pytest.fixtures.FixtureRequest"),
}

Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ scripts.clean = "git clean -fdX docs"

[tool.hatch.envs.hatch-test]
features = [ "test" ]
extra-dependencies = [ "ipykernel" ]
extra-dependencies = [ "ipykernel", "ipycytoscape" ]
env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
overrides.matrix.extras.features = [
{ if = [ "full" ], value = "full" },
Expand Down Expand Up @@ -106,14 +106,17 @@ lint.flake8-type-checking.strict = true
lint.isort.known-first-party = [ "fast_array_utils" ]
lint.isort.lines-after-imports = 2
lint.isort.required-imports = [ "from __future__ import annotations" ]
lint.pydocstyle.convention = "numpy"

[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
"--strict-markers",
"--doctest-modules",
"--pyargs",
"-ptesting.fast_array_utils.pytest",
]
testpaths = [ "./tests", "fast_array_utils" ]
filterwarnings = [
"error",
# codspeed seems to break this dtype added by h5py
Expand Down
15 changes: 15 additions & 0 deletions src/fast_array_utils/_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from numbers import Integral


def validate_axis(axis: int | None) -> None:
    """Validate an ``axis`` argument.

    Params
    ------
    axis
        The axis to check. :data:`None` is always accepted.

    Raises
    ------
    TypeError
        If ``axis`` is neither :data:`None` nor an integral number.
    NotImplementedError
        If ``axis`` is an integer other than 0 or 1.
    """
    if axis is None:
        return
    is_integral = isinstance(axis, Integral)
    if not is_integral:  # pragma: no cover
        error = "axis must be integer or None."
        raise TypeError(error)
    if axis not in (0, 1):  # pragma: no cover
        error = "We only support axis 0 and 1 at the moment"
        raise NotImplementedError(error)
3 changes: 2 additions & 1 deletion src/fast_array_utils/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from __future__ import annotations

from ._is_constant import is_constant
from ._sum import sum


__all__ = ["sum"]
__all__ = ["is_constant", "sum"]
156 changes: 156 additions & 0 deletions src/fast_array_utils/stats/_is_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from functools import partial, singledispatch
from typing import TYPE_CHECKING, Any, cast, overload

import numba
import numpy as np
from numpy.typing import NDArray

from .. import types
from .._validation import validate_axis


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Literal, TypeVar

C = TypeVar("C", bound=Callable[..., Any])


@overload
def is_constant(
    a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None
) -> types.DaskArray: ...
@overload
def is_constant(
    a: NDArray[Any] | types.CSBase, /, *, axis: None = None
) -> bool: ...
@overload
def is_constant(
    a: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]
) -> NDArray[np.bool]: ...


def is_constant(
    a: NDArray[Any] | types.CSBase | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
) -> bool | NDArray[np.bool] | types.DaskArray:
    """Determine whether the values of an array are all identical.

    Params
    ------
    a
        Array to check.
    axis
        Axis to reduce over. :data:`None` reduces over the whole array.

    Returns
    -------
    If ``axis`` is :data:`None`, whether all values in ``a`` are the same.
    Otherwise a boolean array in which :data:`True` marks
    constant columns (``axis=0``) or constant rows (``axis=1``).

    Example
    -------
    >>> a = np.array([[0, 1], [0, 0]])
    >>> a
    array([[0, 1],
           [0, 0]])
    >>> is_constant(a)
    False
    >>> is_constant(a, axis=0)
    array([ True, False])
    >>> is_constant(a, axis=1)
    array([False,  True])

    """
    # Check the axis first, then hand over to the implementation registered for ``type(a)``.
    validate_axis(axis)
    return _is_constant(a, axis=axis)


@singledispatch
def _is_constant(
    a: NDArray[Any] | types.CSBase | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
) -> bool | NDArray[np.bool] | types.DaskArray:  # pragma: no cover
    """Fallback of the single-dispatch function, reached for unregistered array types.

    Params
    ------
    a
        Array of a type that has no registered implementation.
    axis
        Unused here; present so all registered implementations share one signature.

    Raises
    ------
    NotImplementedError
        Always, naming the unsupported type.
    """
    msg = f"is_constant is not implemented for arrays of type {type(a).__name__!r}"
    raise NotImplementedError(msg)


@_is_constant.register(np.ndarray)
def _(a: NDArray[Any], /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
# Should eventually support nd, not now.
match axis:
case None:
return bool((a == a.flat[0]).all())
case 0:
return _is_constant_rows(a.T)
case 1:
return _is_constant_rows(a)


def _is_constant_rows(a: NDArray[Any]) -> NDArray[np.bool]:
    """Check for each row of ``a`` whether all of its values are equal.

    Params
    ------
    a
        Array with at least two dimensions; rows are compared along axis 1.

    Returns
    -------
    Boolean array with axis 1 reduced, :data:`True` where a row is constant.
    Rows without any values (zero columns) count as constant.
    """
    # Slicing with ``:1`` (instead of indexing with ``0``) keeps the axis, so the
    # first column broadcasts against ``a`` and zero-column input does not raise.
    first_column = a[:, :1]
    return cast("NDArray[np.bool]", (a == first_column).all(axis=1))


@_is_constant.register(types.CSBase)
def _(a: types.CSBase, /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
n_row, n_col = a.shape
if axis is None:
if len(a.data) == n_row * n_col:
return is_constant(cast(NDArray[Any], a.data))
return bool((a.data == 0).all())
shape = (n_row, n_col) if axis == 1 else (n_col, n_row)
match axis, a.format:
case 0, "csr":
a = a.T.tocsr()
case 1, "csc":
a = a.T.tocsc()
return _is_constant_csr_rows(a.data, a.indptr, shape)


@numba.njit(cache=True)
def _is_constant_csr_rows(
    data: NDArray[np.number[Any]],
    indptr: NDArray[np.integer[Any]],
    shape: tuple[int, int],
) -> NDArray[np.bool]:
    """Check each compressed line of a sparse container for constant values.

    Params
    ------
    data
        Stored values of the sparse container.
    indptr
        Offsets into ``data``; line ``i`` owns ``data[indptr[i]:indptr[i + 1]]``.
    shape
        Only ``shape[1]`` is read: the full length of one line.

    Returns
    -------
    Boolean array with one entry per line, :data:`True` where the line is constant.
    """
    n = len(indptr) - 1
    result = np.ones(n, dtype=np.bool)
    # NOTE: ``njit`` is used without ``parallel=True``, so ``prange`` runs like ``range``.
    for i in numba.prange(n):
        start = indptr[i]
        stop = indptr[i + 1]
        n_stored = stop - start
        # A fully stored line is compared against its first value. A line with
        # implicit entries contains zeros, so its stored values must all be zero.
        # ``n_stored > 0`` keeps zero-length lines from reading past ``data``.
        val = data[start] if n_stored > 0 and n_stored == shape[1] else 0
        for j in range(start, stop):
            if data[j] != val:
                result[i] = False
                break
    return result


@_is_constant.register(types.DaskArray)
def _(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray:
    """Check a dask array for constant values and return the result as a dask array.

    Params
    ------
    a
        Dask array to check.
    axis
        :data:`None` to compare all values, or 0/1 to check columns/rows individually.
    """
    if TYPE_CHECKING:
        from dask.array.core import map_blocks
        from dask.array.overlap import map_overlap
    else:
        from dask.array import map_blocks, map_overlap

    if axis is not None:
        # Run the in-memory implementation on the blocks, dropping the reduced axis.
        per_block = partial(is_constant, axis=axis)
        reduced = map_blocks(  # type: ignore[no-untyped-call]
            per_block, a, drop_axis=axis, meta=np.array([], dtype=np.bool)
        )
        return cast(types.DaskArray, reduced)

    if isinstance(a._meta, np.ndarray):  # noqa: SLF001
        # NumPy-backed chunks: compare everything against the (eagerly computed) first value.
        all_equal = (a == a[0, 0].compute()).all()
    else:
        all_equal = map_overlap(  # type: ignore[no-untyped-call]
            lambda block: np.array([[is_constant(block)]]),
            a,
            # use asymmetric overlaps to avoid unnecessary computation
            depth={dim: (0, 1) for dim in range(a.ndim)},
            trim=False,
            meta=np.array([], dtype=bool),
        ).all()
    rv = cast(types.DaskArray, all_equal)
    # Apply ``bool`` to the reduced result (presumably so computing yields a Python bool).
    as_bool = map_blocks(bool, rv, meta=np.array([], dtype=bool))  # type: ignore[no-untyped-call]
    return cast(types.DaskArray, as_bool)
6 changes: 4 additions & 2 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from numpy.typing import NDArray

from .. import types
from .._validation import validate_axis


if TYPE_CHECKING:
Expand All @@ -26,7 +27,7 @@ def sum(
) -> NDArray[Any]: ...
@overload
def sum(
x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
) -> types.DaskArray: ...


Expand All @@ -49,6 +50,7 @@ def sum(
:func:`numpy.sum`

"""
validate_axis(axis)
return _sum(x, axis=axis, dtype=dtype)


Expand Down Expand Up @@ -92,7 +94,7 @@ def sum_drop_keepdims(
a: NDArray[Any] | types.CSBase,
/,
*,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None,
dtype: DTypeLike | None = None,
keepdims: bool = False,
) -> NDArray[Any]:
Expand Down
21 changes: 20 additions & 1 deletion src/testing/fast_array_utils/pytest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Callable, Generator

from _pytest.nodes import Node
else:
Expand Down Expand Up @@ -107,3 +107,22 @@ def conversion_context(

with h5py.File(tmp_path, "x") as f:
yield ConversionContext(hdf5_file=f)


@pytest.fixture
def dask_viz(request: pytest.FixtureRequest, cache: pytest.Cache) -> Callable[[object], None]:
    """Provide a helper that visualizes a dask array into pytest's cache directory.

    The helper does nothing for objects that are not dask arrays
    or when ``ipycytoscape`` cannot be found.
    """

    def viz(obj: object) -> None:
        from fast_array_utils.types import DaskArray

        if not isinstance(obj, DaskArray):
            return
        if not find_spec("ipycytoscape"):
            return

        if TYPE_CHECKING:
            from dask.base import visualize
        else:
            from dask import visualize

        # One output per test, named after the requesting test node.
        target = cache.mkdir("dask-viz") / cast(Node, request.node).name
        visualize(obj, filename=str(target), engine="ipycytoscape")

    return viz
55 changes: 51 additions & 4 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any, Protocol, TypeAlias

from numpy.typing import NDArray
from pytest_codspeed import BenchmarkFixture

from testing.fast_array_utils import ArrayType

DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool_]
DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool]
DTypeOut = type[np.float32 | np.float64 | np.int64]

Benchmarkable: TypeAlias = NDArray[Any] | types.CSBase

class BenchFun(Protocol): # noqa: D101
Expand All @@ -40,7 +42,7 @@ def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]:
return cast(Literal[0, 1, None], request.param)


@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool_])
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
def dtype_in(request: pytest.FixtureRequest) -> DTypeIn:
return cast(DTypeIn, request.param)

Expand Down Expand Up @@ -79,17 +81,62 @@ def test_sum(

if dtype_arg is not None:
assert sum_.dtype == dtype_arg, (sum_.dtype, dtype_arg)
elif dtype_in in {np.bool_, np.int32}:
elif dtype_in in {np.bool, np.int32}:
assert sum_.dtype == np.int64
else:
assert sum_.dtype == dtype_in

np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis, dtype=dtype_arg))


@pytest.mark.array_type(skip=Flags.Disk)
@pytest.mark.parametrize(
    ("axis", "expected"),
    [
        pytest.param(None, False, id="None"),
        pytest.param(0, [True, True, False, False], id="0"),
        pytest.param(1, [False, False, True, True, False, True], id="1"),
    ],
)
def test_is_constant(
    array_type: ArrayType, axis: Literal[0, 1, None], expected: bool | list[bool]
) -> None:
    """Check ``is_constant`` against hand-computed results for every axis."""
    # Columns 0 and 1 are constant; rows 2, 3 and 5 are constant.
    x_data = [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0],
    ]
    result = stats.is_constant(array_type(x_data), axis=axis)
    if isinstance(result, types.DaskArray):
        result = result.compute()  # type: ignore[no-untyped-call]
    if not isinstance(expected, list):
        # For ``axis=None`` the result must be the actual ``bool`` singleton.
        assert expected is result
        return
    np.testing.assert_array_equal(expected, result)


@pytest.mark.array_type(Flags.Dask)
def test_dask_constant_blocks(
    dask_viz: Callable[[object], None], array_type: ArrayType[types.DaskArray, Any]
) -> None:
    """Tests if is_constant works if each chunk is individually constant."""
    # 4×4 array made of four 2×2 quadrants, each filled with a different value.
    quadrants = np.arange(4).reshape(2, 2)
    x_np = quadrants.repeat(2, axis=0).repeat(2, axis=1)
    x = array_type(x_np)
    assert x.blocks.shape == (2, 2)
    for block in x.blocks.ravel():
        assert stats.is_constant(block).compute()

    # Although each block is constant on its own, the whole array is not.
    result = stats.is_constant(x, axis=None)
    dask_viz(result)
    assert result.compute() is False  # type: ignore[no-untyped-call]


@pytest.mark.benchmark
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
@pytest.mark.parametrize("func", [stats.sum])
@pytest.mark.parametrize("func", [stats.sum, stats.is_constant])
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float
def test_stats_benchmark(
benchmark: BenchmarkFixture,
Expand Down