Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pint: strip, then reattach appropriate units #207

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
73 changes: 62 additions & 11 deletions flox/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import warnings
from functools import partial
from typing import Callable

import numpy as np
import numpy_groupies as npg
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
dtypes=None,
final_dtype=None,
reduction_type="reduce",
units_func: Callable | None = None,
):
"""
Blueprint for computing grouped aggregations.
Expand Down Expand Up @@ -156,6 +158,8 @@ def __init__(
per reduction in ``chunk`` as a tuple.
final_dtype : DType, optional
DType for output. By default, uses dtype of array being reduced.
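units_func : callable, optional
Function used to derive the units of the output. It is called with a
``pint.Quantity`` carrying the units of the input array, and its result
is stored as ``units``. If None, no units are computed.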
dcherian marked this conversation as resolved.
Show resolved Hide resolved
"""
self.name = name
# preprocess before blockwise
Expand Down Expand Up @@ -187,6 +191,8 @@ def __init__(
# The following are set by _initialize_aggregation
self.finalize_kwargs = {}
self.min_count = None
self.units_func = units_func
self.units = None

def _normalize_dtype_fill_value(self, value, name):
value = _atleast_1d(value)
Expand Down Expand Up @@ -235,17 +241,44 @@ def __repr__(self):
final_dtype=np.intp,
)


def identity(x):
    """
    Return the argument unchanged.

    Passed as ``units_func`` to aggregations whose output keeps the
    units of the input.
    """
    return x


def square(x):
    """
    Return ``x`` raised to the second power.

    Passed as ``units_func`` to the variance aggregations, whose output
    units are the square of the input units.
    """
    return pow(x, 2)


def raise_units_error(x):
    """
    Refuse to compute output units for product-style reductions.

    Parameters
    ----------
    x : Any
        Ignored. Accepted only so that this function can be passed as
        ``units_func`` like the other single-argument unit functions.

    Raises
    ------
    ValueError
        Always.
    """
    raise ValueError(
        "Units cannot be supported for prod in general. "
        "We can only attach units when there is an "
        "equal number of members in each group. "
        "Please strip units and then reattach units "
        "to the output manually."
    )


# note that the fill values are the result of np.func([np.nan, np.nan])
# final_fill_value is used for groups that don't exist. This is usually np.nan
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity)
prod = Aggregation(
"prod",
chunk="prod",
combine="prod",
fill_value=1,
final_fill_value=1,
units_func=raise_units_error,
)
nanprod = Aggregation(
"nanprod",
chunk="nanprod",
combine="prod",
fill_value=1,
final_fill_value=dtypes.NA,
units_func=raise_units_error,
)


Expand All @@ -262,6 +295,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanmean = Aggregation(
"nanmean",
Expand All @@ -271,6 +305,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)


Expand All @@ -296,6 +331,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
nanvar = Aggregation(
"nanvar",
Expand All @@ -306,6 +342,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
std = Aggregation(
"std",
Expand All @@ -316,6 +353,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanstd = Aggregation(
"nanstd",
Expand All @@ -329,10 +367,14 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity
)


def argreduce_preprocess(array, axis):
Expand Down Expand Up @@ -420,10 +462,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity)
last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity
)

all_ = Aggregation(
"all",
Expand Down Expand Up @@ -483,6 +529,7 @@ def _initialize_aggregation(
dtype,
array_dtype,
fill_value,
array_units,
min_count: int | None,
finalize_kwargs,
) -> Aggregation:
Expand Down Expand Up @@ -547,4 +594,8 @@ def _initialize_aggregation(
agg.dtype["intermediate"] += (np.intp,)
agg.dtype["numpy"] += (np.intp,)

if array_units is not None and agg.units_func is not None:
import pint

agg.units = agg.units_func(pint.Quantity([1], units=array_units))
return agg
10 changes: 9 additions & 1 deletion flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
generic_aggregate,
)
from .cache import memoize
from .pint_compat import _reattach_units, _strip_units
from .xrutils import is_duck_array, is_duck_dask_array, isnull

if TYPE_CHECKING:
Expand Down Expand Up @@ -1702,6 +1703,8 @@ def groupby_reduce(
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

array, *bys, units = _strip_units(array, *bys)

if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

Expand Down Expand Up @@ -1803,7 +1806,9 @@ def groupby_reduce(
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(
func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs
)

if not has_dask:
results = _reduce_blockwise(
Expand Down Expand Up @@ -1862,4 +1867,7 @@ def groupby_reduce(

if _is_minmax_reduction(func) and is_bool_array:
result = result.astype(bool)

units[0] = agg.units
result, *groups = _reattach_units(result, *groups, units=units)
return (result, *groups)
16 changes: 16 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
except ImportError:
xr_types = () # type: ignore

try:
import pint

pint_types = pint.Quantity
except ImportError:
pint_types = () # type: ignore


def _importorskip(modname, minversion=None):
try:
Expand All @@ -46,6 +53,7 @@ def LooseVersion(vstring):


has_dask, requires_dask = _importorskip("dask")
has_pint, requires_pint = _importorskip("pint")
has_xarray, requires_xarray = _importorskip("xarray")


Expand Down Expand Up @@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None):
xr.testing.assert_identical(a, b)
return

if has_pint and isinstance(a, pint_types) or isinstance(b, pint_types):
assert isinstance(a, pint_types)
assert isinstance(b, pint_types)
assert a.units == b.units

a = a.magnitude
b = b.magnitude

if tolerance is None and (
np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
has_dask,
raise_if_dask_computes,
requires_dask,
requires_pint,
)

labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
Expand Down Expand Up @@ -1321,3 +1322,39 @@ def test_negative_index_factorize_race_condition():
for f in func
]
[dask.compute(out, scheduler="threads") for _ in range(5)]


@requires_pint
@pytest.mark.parametrize("func", ["all", "count", "sum", "var"])
# The chunked case builds a dask array, so skip it when dask is not installed.
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
def test_pint(chunk, func):
    """
    Check that ``groupby_reduce`` on a ``pint.Quantity`` matches the result
    on the bare magnitude, with the expected units attached.
    """
    import pint

    d = dask.array.array([1, 2, 3]) if chunk else np.array([1, 2, 3])
    q = pint.Quantity(d, units="m")

    actual, _ = groupby_reduce(q, [0, 0, 1], func=func)
    expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func)

    # "count" and "all" are compared without units; for the other
    # reductions the reference units come from the matching numpy function.
    if func not in ["count", "all"]:
        units = getattr(np, func)(q).units
        expected = pint.Quantity(expected, units=units)
    assert_equal(expected, actual)


@requires_pint
# The chunked case builds a dask array, so skip it when dask is not installed.
@pytest.mark.parametrize("chunk", [pytest.param(True, marks=requires_dask), False])
def test_pint_prod_error(chunk):
    """
    Check that reducing a ``pint.Quantity`` with ``prod`` raises the
    units error instead of returning a result.
    """
    import pint

    d = dask.array.array([1, 2, 3]) if chunk else np.array([1, 2, 3])
    q = pint.Quantity(d, units="m")

    # Match on the message so an unrelated ValueError does not pass the test.
    with pytest.raises(ValueError, match="strip units"):
        groupby_reduce(q, [0, 0, 1], func="prod")