Skip to content

Commit da8e85a

Browse files
fix: support Dask and cupy/scipy sparse matrices in min/max (#135)
Co-authored-by: Phil Schaf <flying-sheep@web.de>
1 parent 729d35f commit da8e85a

File tree

7 files changed

+65
-20
lines changed

7 files changed

+65
-20
lines changed

src/fast_array_utils/stats/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _generic_op(
223223
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
224224
from ._generic_ops import generic_op
225225

226-
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation '{op}'"
226+
assert dtype is None or op in get_args(DtypeOps), f"`dtype` is not supported for operation {op!r}"
227227

228228
validate_axis(x.ndim, axis)
229229
return generic_op(x, op, axis=axis, keep_cupy_as_array=keep_cupy_as_array, dtype=dtype)

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .. import types
1010
from ._typing import DtypeOps
11-
from ._utils import _dask_inner
11+
from ._utils import _dask_inner, _dtype_kw
1212

1313

1414
if TYPE_CHECKING:
@@ -29,8 +29,8 @@ def _run_numpy_op(
2929
axis: Literal[0, 1] | None = None,
3030
dtype: DTypeLike | None = None,
3131
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
32-
kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
33-
return getattr(np, op)(x, axis=axis, **kwargs) # type: ignore[no-any-return]
32+
arr = cast("NDArray[Any] | np.number[Any] | types.CupyArray | types.CupyCOOMatrix | types.DaskArray", getattr(np, op)(x, axis=axis, **_dtype_kw(dtype, op)))
33+
return arr.toarray() if isinstance(arr, types.CupyCOOMatrix) else arr
3434

3535

3636
@singledispatch
@@ -83,14 +83,15 @@ def _generic_op_cs(
8383
# just convert to sparse array, then `return x.{op}(dtype=dtype)`
8484
# https://github.com/scipy/scipy/issues/23768
8585

86-
kwargs = {"dtype": dtype} if op in get_args(DtypeOps) else {}
8786
if axis is None:
88-
return cast("np.number[Any]", getattr(x.data, op)(**kwargs))
87+
return cast("np.number[Any]", getattr(x.data, op)(**_dtype_kw(dtype, op)))
8988
if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
9089
assert isinstance(dtype, np.dtype | type | None)
9190
# convert to array so dimensions collapse as expected
92-
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **kwargs) # type: ignore[call-overload]
93-
return cast("NDArray[Any] | np.number[Any]", getattr(x, op)(axis=axis))
91+
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[arg-type]
92+
rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
93+
# old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
94+
return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
9495

9596

9697
@generic_op.register(types.DaskArray)

src/fast_array_utils/stats/_typing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: MPL-2.0
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Literal, Protocol
4+
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypedDict, TypeVar
55

66
import numpy as np
77

@@ -49,3 +49,10 @@ def __call__(
4949
NoDtypeOps = Literal["max", "min"]
5050
DtypeOps = Literal["sum"]
5151
Ops: TypeAlias = NoDtypeOps | DtypeOps
52+
53+
54+
_DT = TypeVar("_DT", bound="DTypeLike")
55+
56+
57+
class DTypeKw(TypedDict, Generic[_DT], total=False):
58+
dtype: _DT

src/fast_array_utils/stats/_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from __future__ import annotations
33

44
from functools import partial
5-
from typing import TYPE_CHECKING, Literal, cast, get_args
5+
from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args
66

77
import numpy as np
88
from numpy.exceptions import AxisError
99

1010
from .. import types
11+
from ..typing import GpuArray
1112
from ._typing import DtypeOps
1213

1314

@@ -16,8 +17,8 @@
1617

1718
from numpy.typing import DTypeLike, NDArray
1819

19-
from ..typing import CpuArray, GpuArray
20-
from ._typing import Ops
20+
from ..typing import CpuArray
21+
from ._typing import DTypeKw, Ops
2122

2223
ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None
2324

@@ -65,13 +66,17 @@ def _dask_block(
6566
axis: ComplexAxis = None,
6667
dtype: DTypeLike | None = None,
6768
keepdims: bool = False,
69+
computing_meta: bool = False,
6870
) -> NDArray[Any] | types.CupyArray:
6971
from . import max, min, sum
7072

73+
if computing_meta: # dask.blockwise doesn’t allow to pass `meta` in, and reductions below don’t handle a 0d matrix
74+
return (types.CupyArray if isinstance(a, GpuArray) else np.ndarray)((), dtype or a.dtype)
75+
7176
fns = {fn.__name__: fn for fn in (min, max, sum)}
7277

7378
axis = _normalize_axis(axis, a.ndim)
74-
rv = fns[op](a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload]
79+
rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload]
7580
shape = _get_shape(rv, axis=axis, keepdims=keepdims)
7681
return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape))
7782

@@ -105,5 +110,12 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
105110
assert axis is not None
106111
return (1, a.size) if axis == 0 else (a.size, 1)
107112
case _: # pragma: no cover
108-
msg = f"{keepdims=}, {type(a)}"
113+
msg = f"{keepdims=}, {a.ndim=}, {type(a)=}"
109114
raise AssertionError(msg)
115+
116+
117+
DT = TypeVar("DT", bound="DTypeLike")
118+
119+
120+
def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]:
121+
return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {}

tests/test_stats.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,20 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
201201
np.testing.assert_array_equal(sum_, expected)
202202

203203

204+
@pytest.mark.array_type(skip=ATS_SPARSE_DS)
205+
@pytest.mark.parametrize("func", [stats.min, stats.max])
206+
def test_min_max(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None, func: StatFunNoDtype) -> None:
207+
rng = np.random.default_rng(0)
208+
np_arr = rng.random((100, 100))
209+
arr = array_type(np_arr)
210+
211+
result = to_np_dense_checked(func(arr, axis=axis), axis, arr)
212+
213+
expected = (np.min if func is stats.min else np.max)(np_arr, axis=axis)
214+
np.testing.assert_array_equal(result, expected)
215+
216+
217+
@pytest.mark.parametrize("func", [stats.sum, stats.min, stats.max])
204218
@pytest.mark.parametrize(
205219
"data",
206220
[
@@ -211,14 +225,15 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray
211225
)
212226
@pytest.mark.parametrize("axis", [0, 1])
213227
@pytest.mark.array_type(Flags.Dask)
214-
def test_sum_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]]) -> None:
228+
def test_dask_shapes(array_type: ArrayType[types.DaskArray], axis: Literal[0, 1], data: list[list[int]], func: StatFunNoDtype) -> None:
215229
np_arr = np.array(data, dtype=np.float32)
216230
arr = array_type(np_arr)
217231
assert 1 in arr.chunksize, "This test is supposed to test 1×n and n×1 chunk sizes"
218-
sum_ = cast("NDArray[Any] | types.CupyArray", stats.sum(arr, axis=axis).compute())
219-
if isinstance(sum_, types.CupyArray):
220-
sum_ = sum_.get()
221-
np.testing.assert_almost_equal(np_arr.sum(axis=axis), sum_)
232+
stat = cast("NDArray[Any] | types.CupyArray", func(arr, axis=axis).compute())
233+
if isinstance(stat, types.CupyArray):
234+
stat = stat.get()
235+
np_func = getattr(np, func.__name__)
236+
np.testing.assert_almost_equal(stat, np_func(np_arr, axis=axis))
222237

223238

224239
@pytest.mark.array_type(skip=ATS_SPARSE_DS)

typings/cupy/_core/core.pyi

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ from types import EllipsisType
33
from typing import Any, Literal, Self, overload
44

55
import numpy as np
6-
from cupy.cuda import Stream
6+
from cupy.cuda import MemoryPointer, Stream
77
from numpy._core.multiarray import flagsobj
88
from numpy.typing import DTypeLike, NDArray
99

@@ -14,6 +14,15 @@ class ndarray:
1414
ndim: int
1515
flags: flagsobj
1616

17+
def __init__(
18+
self,
19+
shape: tuple[int, ...],
20+
dtype: DTypeLike | None = ...,
21+
memptr: MemoryPointer | None = None,
22+
strides: tuple[int, ...] | None = None,
23+
order: Literal["C", "F"] = "C",
24+
) -> None: ...
25+
1726
# cupy-specific
1827
def get(
1928
self, stream: Stream | None = None, order: Literal["C", "F", "A"] = "C", out: NDArray[Any] | None = None, blocking: bool = True

typings/cupy/cuda.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# SPDX-License-Identifier: MPL-2.0
22
class Stream: ...
3+
class MemoryPointer: ...

0 commit comments

Comments
 (0)