Skip to content

Commit d459cca

Browse files
authored
Add stats.is_constant (#21)
1 parent ef6b847 commit d459cca

File tree

8 files changed

+254
-9
lines changed

8 files changed

+254
-9
lines changed

docs/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
)
6161
# Try overriding type paths
6262
qualname_overrides = autodoc_type_aliases = {
63+
"np.bool": ("py:data", "numpy.bool"),
6364
"np.dtype": "numpy.dtype",
6465
"np.number": "numpy.number",
6566
"np.integer": "numpy.integer",
@@ -94,6 +95,7 @@
9495
("py:class", "ArrayLike"),
9596
("py:class", "DTypeLike"),
9697
("py:class", "NDArray"),
98+
("py:class", "np.bool"),
9799
("py:class", "_pytest.fixtures.FixtureRequest"),
98100
}
99101

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ scripts.clean = "git clean -fdX docs"
6060

6161
[tool.hatch.envs.hatch-test]
6262
features = [ "test" ]
63-
extra-dependencies = [ "ipykernel" ]
63+
extra-dependencies = [ "ipykernel", "ipycytoscape" ]
6464
env-vars.CODSPEED_PROFILE_FOLDER = "test-data/codspeed"
6565
overrides.matrix.extras.features = [
6666
{ if = [ "full" ], value = "full" },
@@ -106,14 +106,17 @@ lint.flake8-type-checking.strict = true
106106
lint.isort.known-first-party = [ "fast_array_utils" ]
107107
lint.isort.lines-after-imports = 2
108108
lint.isort.required-imports = [ "from __future__ import annotations" ]
109+
lint.pydocstyle.convention = "numpy"
109110

110111
[tool.pytest.ini_options]
111112
addopts = [
112113
"--import-mode=importlib",
113114
"--strict-markers",
115+
"--doctest-modules",
114116
"--pyargs",
115117
"-ptesting.fast_array_utils.pytest",
116118
]
119+
testpaths = [ "./tests", "fast_array_utils" ]
117120
filterwarnings = [
118121
"error",
119122
# codspeed seems to break this dtype added by h5py

src/fast_array_utils/_validation.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from numbers import Integral
5+
6+
7+
def validate_axis(axis: int | None) -> None:
    """Check that ``axis`` is a supported reduction axis.

    ``None`` is accepted unchanged; any other value has to be an integer equal to 0 or 1.

    Raises
    ------
    TypeError
        If ``axis`` is neither ``None`` nor an integer.
    NotImplementedError
        If ``axis`` is an integer other than 0 or 1.
    """
    if axis is None:
        return
    if isinstance(axis, Integral):
        if axis in (0, 1):
            return
        msg = "We only support axis 0 and 1 at the moment"  # pragma: no cover
        raise NotImplementedError(msg)  # pragma: no cover
    msg = "axis must be integer or None."  # pragma: no cover
    raise TypeError(msg)  # pragma: no cover

src/fast_array_utils/stats/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
from __future__ import annotations
55

6+
from ._is_constant import is_constant
67
from ._sum import sum
78

89

9-
__all__ = ["sum"]
10+
__all__ = ["is_constant", "sum"]
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# SPDX-License-Identifier: MPL-2.0
2+
from __future__ import annotations
3+
4+
from functools import partial, singledispatch
5+
from typing import TYPE_CHECKING, Any, cast, overload
6+
7+
import numba
8+
import numpy as np
9+
from numpy.typing import NDArray
10+
11+
from .. import types
12+
from .._validation import validate_axis
13+
14+
15+
if TYPE_CHECKING:
16+
from collections.abc import Callable
17+
from typing import Literal, TypeVar
18+
19+
C = TypeVar("C", bound=Callable[..., Any])
20+
21+
22+
@overload
def is_constant(
    a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None
) -> types.DaskArray: ...
@overload
def is_constant(
    a: NDArray[Any] | types.CSBase, /, *, axis: None = None
) -> bool: ...
@overload
def is_constant(
    a: NDArray[Any] | types.CSBase, /, *, axis: Literal[0, 1]
) -> NDArray[np.bool]: ...


def is_constant(
    a: NDArray[Any] | types.CSBase | types.DaskArray,
    /,
    *,
    axis: Literal[0, 1, None] = None,
) -> bool | NDArray[np.bool] | types.DaskArray:
    """Determine whether the values of an array are all equal.

    Params
    ------
    a
        Array to check
    axis
        Axis to reduce over: ``0`` checks every column, ``1`` every row,
        and :data:`None` the array as a whole.

    Returns
    -------
    If ``axis`` is :data:`None`, whether all values are equal.
    Otherwise a boolean array in which :data:`True` marks constant columns/rows.
    For Dask arrays, the result is itself a Dask array.

    Example
    -------
    >>> a = np.array([[0, 1], [0, 0]])
    >>> a
    array([[0, 1],
           [0, 0]])
    >>> is_constant(a)
    False
    >>> is_constant(a, axis=0)
    array([ True, False])
    >>> is_constant(a, axis=1)
    array([False,  True])

    """
    # Reject unsupported axes up front so the type-specific implementations can rely on 0, 1 or None.
    validate_axis(axis)
    return _is_constant(a, axis=axis)
63+
64+
65+
@singledispatch
66+
def _is_constant(
67+
a: NDArray[Any] | types.CSBase | types.DaskArray, /, *, axis: Literal[0, 1, None] = None
68+
) -> bool | NDArray[np.bool]: # pragma: no cover
69+
raise NotImplementedError
70+
71+
72+
@_is_constant.register(np.ndarray)
73+
def _(a: NDArray[Any], /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
74+
# Should eventually support nd, not now.
75+
match axis:
76+
case None:
77+
return bool((a == a.flat[0]).all())
78+
case 0:
79+
return _is_constant_rows(a.T)
80+
case 1:
81+
return _is_constant_rows(a)
82+
83+
84+
def _is_constant_rows(a: NDArray[Any]) -> NDArray[np.bool]:
85+
b = np.broadcast_to(a[:, 0][:, np.newaxis], a.shape)
86+
return cast(NDArray[np.bool], (a == b).all(axis=1))
87+
88+
89+
@_is_constant.register(types.CSBase)
def _(a: types.CSBase, /, *, axis: Literal[0, 1, None] = None) -> bool | NDArray[np.bool]:
    """Check a compressed sparse matrix for constant values.

    With ``axis=None`` a :class:`bool` is returned; otherwise the matrix is brought into a
    layout whose ``indptr`` delimits the slices to check and the per-slice check is delegated
    to ``_is_constant_csr_rows``.
    """
    n_row, n_col = a.shape
    if axis is None:
        n_stored = len(a.data)
        if n_stored == n_row * n_col:
            # As many stored values as cells: compare the stored values with each other.
            return is_constant(cast(NDArray[Any], a.data))
        # Otherwise (presumably) some cells are implicit zeros, so all stored values must be 0.
        return bool((a.data == 0).all())

    # ``shape[1]`` is the length of one slice along the reduced axis.
    shape = (n_row, n_col) if axis == 1 else (n_col, n_row)
    fmt = a.format
    if axis == 0 and fmt == "csr":
        a = a.T.tocsr()
    elif axis == 1 and fmt == "csc":
        a = a.T.tocsc()
    return _is_constant_csr_rows(a.data, a.indptr, shape)
103+
104+
105+
@numba.njit(cache=True)
def _is_constant_csr_rows(
    data: NDArray[np.number[Any]],
    indptr: NDArray[np.integer[Any]],
    shape: tuple[int, int],
) -> NDArray[np.bool]:
    """Check each ``indptr``-delimited slice of ``data`` for constant values.

    Params
    ------
    data
        Stored values of a compressed sparse matrix.
    indptr
        Index pointers: slice ``i`` is ``data[indptr[i]:indptr[i + 1]]``.
    shape
        Only ``shape[1]`` is used: the full length of one slice including implicit zeros.

    Returns
    -------
    Boolean array with one entry per slice, :data:`True` where the slice is constant.
    """
    n = len(indptr) - 1
    result = np.ones(n, dtype=np.bool)
    # NOTE(review): ``parallel=True`` is not passed to ``njit``, so ``prange`` presumably
    # behaves like a plain ``range`` here — confirm against the numba documentation.
    for i in numba.prange(n):
        start = indptr[i]
        stop = indptr[i + 1]
        # A slice with ``shape[1]`` stored values has no implicit zeros, so its values are
        # compared with its first stored value; otherwise they must all equal the implicit 0.
        # ``start < stop`` prevents reading ``data[start]`` for zero-length slices
        # (``shape[1] == 0``), where that element does not exist.
        val = data[start] if start < stop and stop - start == shape[1] else 0
        for j in range(start, stop):
            if data[j] != val:
                result[i] = False
                break
    return result
122+
123+
124+
@_is_constant.register(types.DaskArray)
def _(a: types.DaskArray, /, *, axis: Literal[0, 1, None] = None) -> types.DaskArray:
    """Check a Dask array for constant values, returning a Dask array.

    With an ``axis``, ``is_constant`` is mapped over the blocks and that axis is dropped.
    With ``axis=None`` the result is a Dask array whose blocks are converted with :class:`bool`.
    """
    if TYPE_CHECKING:
        from dask.array.core import map_blocks
        from dask.array.overlap import map_overlap
    else:
        from dask.array import map_blocks, map_overlap

    if axis is not None:
        per_block = partial(is_constant, axis=axis)
        reduced = map_blocks(  # type: ignore[no-untyped-call]
            per_block, a, drop_axis=axis, meta=np.array([], dtype=np.bool)
        )
        return cast(types.DaskArray, reduced)

    if isinstance(a._meta, np.ndarray):  # noqa: SLF001
        # NumPy-backed chunks: compare everything with the first element.
        # Note that this computes ``a[0, 0]`` eagerly while the graph is being built.
        all_equal = (a == a[0, 0].compute()).all()
    else:
        # Other chunk types: check each chunk with ``is_constant``. Presumably the one-element
        # overlap with the following chunk is what ties neighbouring chunks together — confirm.
        all_equal = map_overlap(  # type: ignore[no-untyped-call]
            lambda block: np.array([[is_constant(block)]]),
            a,
            # use asymmetric overlaps to avoid unnecessary computation
            depth=dict.fromkeys(range(a.ndim), (0, 1)),
            trim=False,
            meta=np.array([], dtype=bool),
        ).all()
    rv = cast(types.DaskArray, all_equal)
    return cast(
        types.DaskArray,
        map_blocks(bool, rv, meta=np.array([], dtype=bool)),  # type: ignore[no-untyped-call]
    )

src/fast_array_utils/stats/_sum.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy.typing import NDArray
99

1010
from .. import types
11+
from .._validation import validate_axis
1112

1213

1314
if TYPE_CHECKING:
@@ -26,7 +27,7 @@ def sum(
2627
) -> NDArray[Any]: ...
2728
@overload
2829
def sum(
29-
x: types.DaskArray, /, *, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None
30+
x: types.DaskArray, /, *, axis: Literal[0, 1, None] = None, dtype: DTypeLike | None = None
3031
) -> types.DaskArray: ...
3132

3233

@@ -49,6 +50,7 @@ def sum(
4950
:func:`numpy.sum`
5051
5152
"""
53+
validate_axis(axis)
5254
return _sum(x, axis=axis, dtype=dtype)
5355

5456

@@ -92,7 +94,7 @@ def sum_drop_keepdims(
9294
a: NDArray[Any] | types.CSBase,
9395
/,
9496
*,
95-
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1] | None = None,
97+
axis: tuple[Literal[0], Literal[1]] | Literal[0, 1, None] = None,
9698
dtype: DTypeLike | None = None,
9799
keepdims: bool = False,
98100
) -> NDArray[Any]:

src/testing/fast_array_utils/pytest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
if TYPE_CHECKING:
19-
from collections.abc import Generator
19+
from collections.abc import Callable, Generator
2020

2121
from _pytest.nodes import Node
2222
else:
@@ -107,3 +107,22 @@ def conversion_context(
107107

108108
with h5py.File(tmp_path, "x") as f:
109109
yield ConversionContext(hdf5_file=f)
110+
111+
112+
@pytest.fixture
def dask_viz(request: pytest.FixtureRequest, cache: pytest.Cache) -> Callable[[object], None]:
    """Provide a callable that visualizes a Dask array using the ``ipycytoscape`` engine.

    The callable does nothing when its argument is not a Dask array or when
    ``ipycytoscape`` cannot be found. Output is written below the ``dask-viz``
    directory of the pytest cache, named after the requesting test node.
    """

    def viz(obj: object) -> None:
        from fast_array_utils.types import DaskArray

        can_render = isinstance(obj, DaskArray) and find_spec("ipycytoscape")
        if not can_render:
            return

        if TYPE_CHECKING:
            from dask.base import visualize
        else:
            from dask import visualize

        out_dir = cache.mkdir("dask-viz")
        target = out_dir / cast(Node, request.node).name
        visualize(obj, filename=str(target), engine="ipycytoscape")

    return viz

tests/test_stats.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@
1111

1212

1313
if TYPE_CHECKING:
14+
from collections.abc import Callable
1415
from typing import Any, Protocol, TypeAlias
1516

1617
from numpy.typing import NDArray
1718
from pytest_codspeed import BenchmarkFixture
1819

1920
from testing.fast_array_utils import ArrayType
2021

21-
DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool_]
22+
DTypeIn = type[np.float32 | np.float64 | np.int32 | np.bool]
2223
DTypeOut = type[np.float32 | np.float64 | np.int64]
24+
2325
Benchmarkable: TypeAlias = NDArray[Any] | types.CSBase
2426

2527
class BenchFun(Protocol): # noqa: D101
@@ -40,7 +42,7 @@ def axis(request: pytest.FixtureRequest) -> Literal[0, 1, None]:
4042
return cast(Literal[0, 1, None], request.param)
4143

4244

43-
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool_])
45+
@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int32, np.bool])
4446
def dtype_in(request: pytest.FixtureRequest) -> DTypeIn:
4547
return cast(DTypeIn, request.param)
4648

@@ -79,17 +81,62 @@ def test_sum(
7981

8082
if dtype_arg is not None:
8183
assert sum_.dtype == dtype_arg, (sum_.dtype, dtype_arg)
82-
elif dtype_in in {np.bool_, np.int32}:
84+
elif dtype_in in {np.bool, np.int32}:
8385
assert sum_.dtype == np.int64
8486
else:
8587
assert sum_.dtype == dtype_in
8688

8789
np.testing.assert_array_equal(sum_, np.sum(np_arr, axis=axis, dtype=dtype_arg))
8890

8991

92+
@pytest.mark.array_type(skip=Flags.Disk)
@pytest.mark.parametrize(
    ("axis", "expected"),
    [
        pytest.param(None, False, id="None"),
        pytest.param(0, [True, True, False, False], id="0"),
        pytest.param(1, [False, False, True, True, False, True], id="1"),
    ],
)
def test_is_constant(
    array_type: ArrayType, axis: Literal[0, 1, None], expected: bool | list[bool]
) -> None:
    """Compare ``stats.is_constant`` with hand-computed results for each axis."""
    rows = [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0],
    ]
    actual = stats.is_constant(array_type(rows), axis=axis)
    if isinstance(actual, types.DaskArray):
        actual = actual.compute()  # type: ignore[no-untyped-call]
    if not isinstance(expected, list):
        # Scalar case: the result has to be the ``bool`` singleton itself.
        assert expected is actual
        return
    np.testing.assert_array_equal(expected, actual)
120+
121+
122+
@pytest.mark.array_type(Flags.Dask)
def test_dask_constant_blocks(
    dask_viz: Callable[[object], None], array_type: ArrayType[types.DaskArray, Any]
) -> None:
    """Check that an array made of individually constant chunks is not reported as constant."""
    # Blow every entry of [[0, 1], [2, 3]] up into a 2×2 patch.
    base = np.arange(4).reshape(2, 2)
    x_np = base.repeat(2, axis=0).repeat(2, axis=1)
    x = array_type(x_np)
    assert x.blocks.shape == (2, 2)
    for block in x.blocks.ravel():
        assert stats.is_constant(block).compute()

    result = stats.is_constant(x, axis=None)
    dask_viz(result)
    assert result.compute() is False  # type: ignore[no-untyped-call]
135+
136+
90137
@pytest.mark.benchmark
91138
@pytest.mark.array_type(skip=Flags.Matrix | Flags.Dask | Flags.Disk | Flags.Gpu)
92-
@pytest.mark.parametrize("func", [stats.sum])
139+
@pytest.mark.parametrize("func", [stats.sum, stats.is_constant])
93140
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) # random only supports float
94141
def test_stats_benchmark(
95142
benchmark: BenchmarkFixture,

0 commit comments

Comments
 (0)