|
2 | 2 | from __future__ import annotations |
3 | 3 |
|
4 | 4 | from functools import partial |
5 | | -from typing import TYPE_CHECKING, Literal, cast, get_args |
| 5 | +from typing import TYPE_CHECKING, Literal, TypeVar, cast, get_args |
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | from numpy.exceptions import AxisError |
9 | 9 |
|
10 | 10 | from .. import types |
| 11 | +from ..typing import GpuArray |
11 | 12 | from ._typing import DtypeOps |
12 | 13 |
|
13 | 14 |
|
|
16 | 17 |
|
17 | 18 | from numpy.typing import DTypeLike, NDArray |
18 | 19 |
|
19 | | - from ..typing import CpuArray, GpuArray |
20 | | - from ._typing import Ops |
| 20 | + from ..typing import CpuArray |
| 21 | + from ._typing import DTypeKw, Ops |
21 | 22 |
|
22 | 23 | ComplexAxis: TypeAlias = tuple[Literal[0], Literal[1]] | tuple[Literal[0, 1]] | Literal[0, 1] | None |
23 | 24 |
|
@@ -65,13 +66,17 @@ def _dask_block( |
65 | 66 | axis: ComplexAxis = None, |
66 | 67 | dtype: DTypeLike | None = None, |
67 | 68 | keepdims: bool = False, |
| 69 | + computing_meta: bool = False, |
68 | 70 | ) -> NDArray[Any] | types.CupyArray: |
69 | 71 | from . import max, min, sum |
70 | 72 |
|
| 73 | + if computing_meta: # dask.blockwise doesn’t allow to pass `meta` in, and reductions below don’t handle a 0d matrix |
| 74 | + return (types.CupyArray if isinstance(a, GpuArray) else np.ndarray)((), dtype or a.dtype) |
| 75 | + |
71 | 76 | fns = {fn.__name__: fn for fn in (min, max, sum)} |
72 | 77 |
|
73 | 78 | axis = _normalize_axis(axis, a.ndim) |
74 | | - rv = fns[op](a, axis=axis, dtype=dtype, keep_cupy_as_array=True) # type: ignore[misc,call-overload] |
| 79 | + rv = fns[op](a, axis=axis, keep_cupy_as_array=True, **_dtype_kw(dtype, op)) # type: ignore[call-overload] |
75 | 80 | shape = _get_shape(rv, axis=axis, keepdims=keepdims) |
76 | 81 | return cast("NDArray[Any] | types.CupyArray", rv.reshape(shape)) |
77 | 82 |
|
@@ -105,5 +110,12 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite |
105 | 110 | assert axis is not None |
106 | 111 | return (1, a.size) if axis == 0 else (a.size, 1) |
107 | 112 | case _: # pragma: no cover |
108 | | - msg = f"{keepdims=}, {type(a)}" |
| 113 | + msg = f"{keepdims=}, {a.ndim=}, {type(a)=}" |
109 | 114 | raise AssertionError(msg) |
| 115 | + |
| 116 | + |
| 117 | +DT = TypeVar("DT", bound="DTypeLike") |
| 118 | + |
| 119 | + |
| 120 | +def _dtype_kw(dtype: DT | None, op: Ops) -> DTypeKw[DT]: |
| 121 | + return {"dtype": dtype} if dtype is not None and op in get_args(DtypeOps) else {} |
0 commit comments