Commit 68b122e

Support quantile, median, mode with method="blockwise". (#269)
* Support quantile, median with method="blockwise". We allow method="blockwise" when grouping by a dask array. This can only work if we have expected_groups, and set reindex=True.
* Update flox/core.py
* fix comment
* Update validate_reindex test
* Fix
* Fix
* Raise early if `q` is not provided for quantile
* WIP test
* narrow type
* fix type
* Mode + Tests
* limit tests
* cleanup tests
* fix bool tests
* Revert "limit tests"

  This reverts commit d46c3ae.
* Small cleanup
* more cleanup
* fix test
* ignore scipy typing
* update docs
1 parent 30d522d commit 68b122e
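
A minimal usage sketch of what this commit enables. The call pattern follows the docstring changes further down; the array values and group labels are invented for illustration, and `flox.groupby_reduce` is assumed to be the public entry point.

import numpy as np
import flox

array = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
by = np.array([0, 0, 0, 1, 1, 1])

# median/mode need no extra kwargs
result, groups = flox.groupby_reduce(array, by, func="nanmedian")

# quantile requires `q`, passed via finalize_kwargs (omitting it now raises early)
result, groups = flox.groupby_reduce(array, by, func="quantile", finalize_kwargs={"q": 0.5})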

10 files changed: +261 additions, -62 deletions

ci/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -22,3 +22,4 @@ dependencies:
   - pooch
   - toolz
   - numba
+  - scipy

docs/source/aggregations.md

Lines changed: 5 additions & 2 deletions
@@ -11,8 +11,11 @@ the `func` kwarg:
 - `"std"`, `"nanstd"`
 - `"argmin"`
 - `"argmax"`
-- `"first"`
-- `"last"`
+- `"first"`, `"nanfirst"`
+- `"last"`, `"nanlast"`
+- `"median"`, `"nanmedian"`
+- `"mode"`, `"nanmode"`
+- `"quantile"`, `"nanquantile"`

 ```{tip}
 We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome!

docs/source/user-stories/custom-aggregations.ipynb

Lines changed: 7 additions & 2 deletions
@@ -15,8 +15,13 @@
     ">\n",
     "> A = da.groupby(['lon_bins', 'lat_bins']).mode()\n",
     "\n",
-    "This notebook will describe how to accomplish this using a custom `Aggregation`\n",
-    "since `mode` and `median` aren't supported by flox yet.\n"
+    "This notebook will describe how to accomplish this using a custom `Aggregation`.\n",
+    "\n",
+    "\n",
+    "```{tip}\n",
+    "flox now supports `mode`, `nanmode`, `quantile`, `nanquantile`, `median`, `nanmedian` using exactly the same \n",
+    "approach as shown below\n",
+    "```\n"
    ]
   },
   {

flox/aggregate_npg.py

Lines changed: 81 additions & 0 deletions
@@ -100,3 +100,84 @@ def _len(group_idx, array, engine, *, func, axis=-1, size=None, fill_value=None,

 len = partial(_len, func="len")
 nanlen = partial(_len, func="nanlen")
+
+
+def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.median,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=np.nanmedian,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.quantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(np.nanquantile, q=q),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def mode_(array, nan_policy, dtype):
+    from scipy.stats import mode
+
+    # npg splits `array` into object arrays for each group
+    # scipy.stats.mode does not like that
+    # here we cast back
+    return mode(array.astype(dtype, copy=False), nan_policy=nan_policy, axis=-1).mode
+
+
+def mode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="propagate", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+
+def nanmode(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dtype=None):
+    return npg.aggregate_numpy.aggregate(
+        group_idx,
+        array,
+        func=partial(mode_, nan_policy="omit", dtype=array.dtype),
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
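
These wrappers all rely on the same numpy_groupies behaviour: passing a callable as `func` applies it to each group's values. A minimal sketch of that underlying pattern, with invented values; the expected result is noted in a comment rather than asserted.

import numpy as np
import numpy_groupies as npg

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 3.0, 2.0, 4.0, 6.0])

# numpy_groupies collects each group's values and applies the callable to them
npg.aggregate_numpy.aggregate(group_idx, array, func=np.median, fill_value=np.nan)
# -> one median per group label: [2., 4.]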

flox/aggregations.py

Lines changed: 36 additions & 12 deletions
@@ -6,7 +6,6 @@
 from typing import TYPE_CHECKING, Any, Callable, TypedDict

 import numpy as np
-import numpy_groupies as npg
 from numpy.typing import DTypeLike

 from . import aggregate_flox, aggregate_npg, xrutils
@@ -35,6 +34,16 @@ class AggDtype(TypedDict):
     intermediate: tuple[np.dtype | type[np.intp], ...]


+def get_npg_aggregation(func, *, engine):
+    try:
+        method_ = getattr(aggregate_npg, func)
+        method = partial(method_, engine=engine)
+    except AttributeError:
+        aggregate = aggregate_npg._get_aggregate(engine).aggregate
+        method = partial(aggregate, func=func)
+    return method
+
+
 def generic_aggregate(
     group_idx,
     array,
@@ -51,14 +60,11 @@ def generic_aggregate(
         try:
             method = getattr(aggregate_flox, func)
         except AttributeError:
-            method = partial(npg.aggregate_numpy.aggregate, func=func)
+            method = get_npg_aggregation(func, engine="numpy")
+
     elif engine in ["numpy", "numba"]:
-        try:
-            method_ = getattr(aggregate_npg, func)
-            method = partial(method_, engine=engine)
-        except AttributeError:
-            aggregate = aggregate_npg._get_aggregate(engine).aggregate
-            method = partial(aggregate, func=func)
+        method = get_npg_aggregation(func, engine=engine)
+
     else:
         raise ValueError(
             f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
@@ -465,10 +471,22 @@ def _pick_second(*x):
     final_dtype=bool,
 )

-# numpy_groupies does not support median
-# And the dask version is really hard!
-# median = Aggregation("median", chunk=None, combine=None, fill_value=None)
-# nanmedian = Aggregation("nanmedian", chunk=None, combine=None, fill_value=None)
+# Support statistical quantities only blockwise
+# The parallel versions will be approximate and are hard to implement!
+median = Aggregation(
+    name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanmedian = Aggregation(
+    name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+quantile = Aggregation(
+    name="quantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+nanquantile = Aggregation(
+    name="nanquantile", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
+)
+mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
+nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)

 aggregations = {
     "any": any_,
@@ -496,6 +514,12 @@ def _pick_second(*x):
     "nanfirst": nanfirst,
     "last": last,
     "nanlast": nanlast,
+    "median": median,
+    "nanmedian": nanmedian,
+    "quantile": quantile,
+    "nanquantile": nanquantile,
+    "mode": mode,
+    "nanmode": nanmode,
 }

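A hedged illustration of the `get_npg_aggregation` dispatch introduced above, assuming the helper is importable from `flox.aggregations`: it should prefer a wrapper defined in `flox.aggregate_npg` (such as the new `median`) and otherwise fall back to numpy_groupies' built-in string-named aggregations.

from flox.aggregations import get_npg_aggregation

# should resolve to the new aggregate_npg.median wrapper added in this commit
median_agg = get_npg_aggregation("median", engine="numpy")

# "sum" has no wrapper in aggregate_npg (as far as I can tell), so this should
# fall back to numpy_groupies' aggregate(..., func="sum")
sum_agg = get_npg_aggregation("sum", engine="numpy")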
flox/core.py

Lines changed: 36 additions & 18 deletions
@@ -1307,15 +1307,14 @@ def dask_groupby_agg(
     assert isinstance(axis, Sequence)
     assert all(ax >= 0 for ax in axis)

-    if method == "blockwise" and not isinstance(by, np.ndarray):
-        raise NotImplementedError
-
     inds = tuple(range(array.ndim))
     name = f"groupby_{agg.name}"
     token = dask.base.tokenize(array, by, agg, expected_groups, axis)

     if expected_groups is None and reindex:
         expected_groups = _get_expected_groups(by, sort=sort)
+    if method == "cohorts":
+        assert reindex is False

     by_input = by

@@ -1349,7 +1348,6 @@ def dask_groupby_agg(
     # b. "_grouped_combine": A more general solution where we tree-reduce the groupby reduction.
     #    This allows us to discover groups at compute time, support argreductions, lower intermediate
     #    memory usage (but method="cohorts" would also work to reduce memory in some cases)
-
     do_simple_combine = not _is_arg_reduction(agg)

     if method == "blockwise":
@@ -1375,7 +1373,7 @@ def dask_groupby_agg(
         partial(
             blockwise_method,
             axis=axis,
-            expected_groups=None if method == "cohorts" else expected_groups,
+            expected_groups=expected_groups if reindex else None,
             engine=engine,
             sort=sort,
         ),
@@ -1468,14 +1466,24 @@ def dask_groupby_agg(

     elif method == "blockwise":
         reduced = intermediate
-        # Here one input chunk → one output chunks
-        # find number of groups in each chunk, this is needed for output chunks
-        # along the reduced axis
-        slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
-        groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
-        groups = (np.concatenate(groups_in_block),)
-        ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
-        group_chunks = (ngroups_per_block,)
+        if reindex:
+            if TYPE_CHECKING:
+                assert expected_groups is not None
+            # TODO: we could have `expected_groups` be a dask array with appropriate chunks
+            # for now, we have a numpy array that is interpreted as listing all group labels
+            # that are present in every chunk
+            groups = (expected_groups,)
+            group_chunks = ((len(expected_groups),),)
+        else:
+            # Here one input chunk → one output chunks
+            # find number of groups in each chunk, this is needed for output chunks
+            # along the reduced axis
+            # TODO: this logic is very specialized for the resampling case
+            slices = slices_from_chunks(tuple(array.chunks[ax] for ax in axis))
+            groups_in_block = tuple(_unique(by_input[slc]) for slc in slices)
+            groups = (np.concatenate(groups_in_block),)
+            ngroups_per_block = tuple(len(grp) for grp in groups_in_block)
+            group_chunks = (ngroups_per_block,)
     else:
         raise ValueError(f"Unknown method={method}.")

@@ -1547,7 +1555,7 @@ def _validate_reindex(
     if reindex is True and not all_numpy:
         if _is_arg_reduction(func):
             raise NotImplementedError
-        if method in ["blockwise", "cohorts"]:
+        if method == "cohorts" or (method == "blockwise" and not any_by_dask):
             raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
@@ -1562,7 +1570,11 @@ def _validate_reindex(
             # have to do the grouped_combine since there's no good fill_value
             reindex = False

-        if method == "blockwise" or _is_arg_reduction(func):
+        if method == "blockwise":
+            # for grouping by dask arrays, we set reindex=True
+            reindex = any_by_dask
+
+        elif _is_arg_reduction(func):
             reindex = False

         elif method == "cohorts":
@@ -1767,7 +1779,10 @@ def groupby_reduce(
     *by : ndarray or DaskArray
         Array of labels to group over. Must be aligned with ``array`` so that
         ``array.shape[-by.ndim :] == by.shape``
-    func : str or Aggregation
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
         Single function name or an Aggregation instance
     expected_groups : (optional) Sequence
         Expected unique labels.
@@ -1835,7 +1850,7 @@ def groupby_reduce(
        boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
        original block size. Avoid that by using ``method="cohorts"``. By default, it is turned off for argreductions.
     finalize_kwargs : dict, optional
-        Kwargs passed to finalize the reduction such as ``ddof`` for var, std.
+        Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile.

     Returns
     -------
@@ -1855,6 +1870,9 @@ def groupby_reduce(
             "Try engine='numpy' or engine='numba' instead."
         )

+    if func == "quantile" and (finalize_kwargs is None or "q" not in finalize_kwargs):
+        raise ValueError("Please pass `q` for quantile calculations.")
+
     bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by)
     nby = len(bys)
     by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
@@ -2023,7 +2041,7 @@ def groupby_reduce(
         result, groups = partial_agg(
             array,
             by_,
-            expected_groups=None if method == "blockwise" else expected_groups,
+            expected_groups=expected_groups,
             agg=agg,
             reindex=reindex,
             method=method,
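
A hedged sketch of the dask path this enables: grouping by a dask array with method="blockwise" now works provided `expected_groups` is passed (reindex is then set internally, per the changes above) and the chunking keeps the grouped axis in a single chunk per block, so every block sees all groups. Shapes, chunks, and labels below are invented:

import dask.array as da
import numpy as np
import flox

# grouped (last) axis is a single chunk per block, so each block contains every group
array = da.random.random((10, 6), chunks=(2, 6))
by = da.from_array(np.array([0, 1, 0, 1, 0, 1]), chunks=6)

result, groups = flox.groupby_reduce(
    array,
    by,
    func="nanmedian",
    expected_groups=np.array([0, 1]),
    method="blockwise",
)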

flox/xarray.py

Lines changed: 6 additions & 3 deletions
@@ -88,8 +88,11 @@ def xarray_reduce(
         Xarray object to reduce
     *by : DataArray or iterable of str or iterable of DataArray
         Variables with which to group by ``obj``
-    func : str or Aggregation
-        Reduction method
+    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
+        "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
+        "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
+        "first", "nanfirst", "last", "nanlast"} or Aggregation
+        Single function name or an Aggregation instance
     expected_groups : str or sequence
         expected group labels corresponding to each `by` variable
     isbin : iterable of bool
@@ -164,7 +167,7 @@ def xarray_reduce(
        boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
        original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
     **finalize_kwargs
-        kwargs passed to the finalize function, like ``ddof`` for var, std.
+        kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile.

     Returns
     -------
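
The xarray-level counterpart, following the updated docstring; the dataset and variable names are invented, and `q` is forwarded through `**finalize_kwargs`:

import numpy as np
import xarray as xr
import flox.xarray

ds = xr.Dataset(
    {"temp": ("time", np.arange(12.0))},
    coords={"labels": ("time", np.repeat([1, 2, 3], 4))},
)

# `q` is collected by **finalize_kwargs and passed on to the quantile aggregation
out = flox.xarray.xarray_reduce(ds, "labels", func="quantile", q=0.75)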

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -118,7 +118,8 @@ module=[
     "matplotlib.*",
     "pandas",
     "setuptools",
-    "toolz"
+    "scipy.*",
+    "toolz",
 ]
 ignore_missing_imports = true


tests/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def LooseVersion(vstring):

 has_dask, requires_dask = _importorskip("dask")
 has_numba, requires_numba = _importorskip("numba")
+has_scipy, requires_scipy = _importorskip("scipy")
 has_xarray, requires_xarray = _importorskip("xarray")

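A hypothetical example of how the new marker would be used in a test module, mirroring the existing `requires_dask`/`requires_numba` pattern (the test itself is made up):

from . import requires_scipy

@requires_scipy
def test_mode():  # hypothetical test; skipped automatically when scipy is missing
    ...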
