Support pint: strip, then reattach appropriate units #207

Open · wants to merge 17 commits into base: main
ci/environment.yml (1 change: 1 addition & 0 deletions)
@@ -11,6 +11,7 @@ dependencies:
- numpy>=1.20
- lxml # for mypy coverage report
- matplotlib
- pint
- pip
- pytest
- pytest-cov
flox/aggregations.py (75 changes: 62 additions & 13 deletions)
@@ -64,8 +64,6 @@ def generic_aggregate(
f"Expected engine to be one of ['flox', 'numpy', 'numba']. Received {engine} instead."
)

group_idx = np.asarray(group_idx, like=array)

with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
result = method(
@@ -131,6 +129,7 @@ def __init__(
dtypes=None,
final_dtype: DTypeLike | None = None,
reduction_type="reduce",
units_func: Callable | None = None,
):
"""
Blueprint for computing grouped aggregations.
@@ -173,6 +172,8 @@ def __init__(
per reduction in ``chunk`` as a tuple.
final_dtype : DType, optional
DType for output. By default, uses dtype of array being reduced.
units_func : callable, optional
Function whose output units are used to infer the units of the result.
"""
self.name = name
# preprocess before blockwise
@@ -206,6 +207,8 @@ def __init__(
# The following are set by _initialize_aggregation
self.finalize_kwargs: dict[Any, Any] = {}
self.min_count: int | None = None
self.units_func: Callable | None = units_func
self.units = None

def _normalize_dtype_fill_value(self, value, name):
value = _atleast_1d(value)
@@ -254,17 +257,44 @@ def __repr__(self) -> str:
final_dtype=np.intp,
)


def identity(x):
return x


def square(x):
return x**2


def raise_units_error(x):
    raise ValueError(
        "Units cannot be supported for prod in general. "
        "We can only attach units when every group has "
        "the same number of members. "
        "Please strip units, then reattach units "
        "to the output manually."
    )
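
For context on the error above: the units of a product depend on how many members are multiplied together, so a single output unit only exists when every group has the same size. A minimal sketch with pint (not part of the diff; assumes pint's numpy function support):

import numpy as np
import pint

ureg = pint.UnitRegistry()
q = ureg.Quantity([2.0, 3.0, 4.0], "m")

# The output units grow with the number of members multiplied:
np.prod(q[:2])  # 6.0 meter ** 2
np.prod(q)      # 24.0 meter ** 3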


# note that the fill values are the result of np.func([np.nan, np.nan])
# final_fill_value is used for groups that don't exist. This is usually np.nan
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0, units_func=identity)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0, units_func=identity)
prod = Aggregation(
"prod",
chunk="prod",
combine="prod",
fill_value=1,
final_fill_value=1,
units_func=raise_units_error,
)
nanprod = Aggregation(
"nanprod",
chunk="nanprod",
combine="prod",
fill_value=1,
final_fill_value=dtypes.NA,
units_func=raise_units_error,
)


@@ -281,6 +311,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanmean = Aggregation(
"nanmean",
@@ -290,6 +321,7 @@ def _mean_finalize(sum_, count):
fill_value=(0, 0),
dtypes=(None, np.intp),
final_dtype=np.floating,
units_func=identity,
)


@@ -315,6 +347,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
nanvar = Aggregation(
"nanvar",
@@ -325,6 +358,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=square,
)
std = Aggregation(
"std",
@@ -335,6 +369,7 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=identity,
)
nanstd = Aggregation(
"nanstd",
@@ -345,13 +380,18 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
final_fill_value=np.nan,
dtypes=(None, None, np.intp),
final_dtype=np.floating,
units_func=identity,
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)
min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF, units_func=identity)
nanmin = Aggregation(
"nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan, units_func=identity
)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF, units_func=identity)
nanmax = Aggregation(
"nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan, units_func=identity
)


def argreduce_preprocess(array, axis):
@@ -439,10 +479,14 @@ def _pick_second(*x):
final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)
first = Aggregation("first", chunk=None, combine=None, fill_value=0, units_func=identity)
last = Aggregation("last", chunk=None, combine=None, fill_value=0, units_func=identity)
nanfirst = Aggregation(
"nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan, units_func=identity
)
nanlast = Aggregation(
"nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan, units_func=identity
)

all_ = Aggregation(
"all",
@@ -502,6 +546,7 @@ def _initialize_aggregation(
dtype,
array_dtype,
fill_value,
array_units,
min_count: int | None,
finalize_kwargs: dict[Any, Any] | None,
) -> Aggregation:
@@ -572,4 +617,8 @@
agg.dtype["intermediate"] += (np.intp,)
agg.dtype["numpy"] += (np.intp,)

if array_units is not None and agg.units_func is not None:
import pint

agg.units = agg.units_func(pint.Quantity([1], units=array_units))
return agg
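
The block above infers output units by probing units_func with a dummy one-element quantity. A short sketch of what that probe yields for the helpers defined earlier (assuming pint is installed):

import pint

probe = pint.Quantity([1], units="m")

identity(probe).units  # <Unit('meter')> -> sum, mean, min, ... keep the input units
square(probe).units    # <Unit('meter ** 2')> -> var/nanvar report squared units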
flox/core.py (10 changes: 9 additions & 1 deletion)
@@ -35,6 +35,7 @@
generic_aggregate,
)
from .cache import memoize
from .pint_compat import _reattach_units, _strip_units
from .xrutils import is_duck_array, is_duck_dask_array, isnull

if TYPE_CHECKING:
@@ -1799,6 +1800,8 @@ def groupby_reduce(
by_is_dask = tuple(is_duck_dask_array(b) for b in bys)
any_by_dask = any(by_is_dask)

array, bys, units = _strip_units(array, *bys)

if method in ["split-reduce", "cohorts"] and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

@@ -1904,7 +1907,9 @@
fill_value = np.nan

kwargs = dict(axis=axis_, fill_value=fill_value, engine=engine)
agg = _initialize_aggregation(func, dtype, array.dtype, fill_value, min_count, finalize_kwargs)
agg = _initialize_aggregation(
func, dtype, array.dtype, fill_value, units[0], min_count, finalize_kwargs
)

groups: tuple[np.ndarray | DaskArray, ...]
if not has_dask:
@@ -1964,4 +1969,7 @@

if _is_minmax_reduction(func) and is_bool_array:
result = result.astype(bool)

units[0] = agg.units
result, *groups = _reattach_units(result, *groups, units=units)
return (result, *groups) # type: ignore[return-value] # Unpack not in mypy yet
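
With these changes, groupby_reduce accepts pint quantities end to end: units are stripped before the reduction and reattached afterwards. A hedged caller-side sketch of the intended behavior:

import numpy as np
import pint

from flox.core import groupby_reduce

q = pint.Quantity(np.array([10.0, 20.0, 30.0]), "m")
result, groups = groupby_reduce(q, np.array([0, 0, 1]), func="sum")
# result is a pint.Quantity in meters; groups remain a plain numpy array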
flox/pint_compat.py (25 changes: 25 additions & 0 deletions, new file)
@@ -0,0 +1,25 @@
def _strip_units(*arrays):
    # Import pint lazily so it stays an optional dependency.
    try:
        import pint

        pint_quantity = (pint.Quantity,)

    except ImportError:
        pint_quantity = ()

    # Split quantities into magnitudes and units; plain arrays pass through with units=None.
    bare = tuple(array.magnitude if isinstance(array, pint_quantity) else array for array in arrays)
    units = [array.units if isinstance(array, pint_quantity) else None for array in arrays]

    return bare[0], bare[1:], units


def _reattach_units(*arrays, units):
    try:
        import pint

        # Re-wrap any array whose units were stripped; leave the rest untouched.
        return tuple(
            pint.Quantity(array, unit) if unit is not None else array
            for array, unit in zip(arrays, units)
        )
    except ImportError:
        # Without pint no units were stripped, so return the arrays unchanged.
        return arrays
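
A quick round-trip sketch of these helpers (assuming pint is installed):

import numpy as np
import pint

q = pint.Quantity(np.array([1.0, 2.0]), "m")
by = np.array([0, 1])

array, bys, units = _strip_units(q, by)
# array -> plain ndarray; bys -> (by,); units -> [<Unit('meter')>, None]

restored, by_ = _reattach_units(array, *bys, units=units)
# restored -> pint.Quantity in meters again; by_ -> unchanged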
tests/__init__.py (16 changes: 16 additions & 0 deletions)
@@ -24,6 +24,13 @@
except ImportError:
xr_types = () # type: ignore

try:
import pint

pint_types = pint.Quantity
except ImportError:
pint_types = () # type: ignore


def _importorskip(modname, minversion=None):
try:
@@ -46,6 +53,7 @@ def LooseVersion(vstring):


has_dask, requires_dask = _importorskip("dask")
has_pint, requires_pint = _importorskip("pint")
has_xarray, requires_xarray = _importorskip("xarray")


@@ -95,6 +103,14 @@ def assert_equal(a, b, tolerance=None):
xr.testing.assert_identical(a, b)
return

if has_pint and (isinstance(a, pint_types) or isinstance(b, pint_types)):
assert isinstance(a, pint_types)
assert isinstance(b, pint_types)
assert a.units == b.units

a = a.magnitude
b = b.magnitude

if tolerance is None and (
np.issubdtype(a.dtype, np.float64) | np.issubdtype(b.dtype, np.float64)
):
tests/test_core.py (40 changes: 40 additions & 0 deletions)
@@ -30,6 +30,7 @@
has_dask,
raise_if_dask_computes,
requires_dask,
requires_pint,
)

labels = np.array([0, 0, 2, 2, 2, 1, 1, 2, 2, 1, 1, 0])
@@ -1339,6 +1340,45 @@ def test_negative_index_factorize_race_condition():
[dask.compute(out, scheduler="threads") for _ in range(5)]


@requires_pint
@pytest.mark.parametrize("func", ALL_FUNCS)
@pytest.mark.parametrize("chunk", [True, False])
def test_pint(chunk, func, engine):
import pint

if func in ["prod", "nanprod"]:
pytest.skip("units are not supported for prod; see test_pint_prod_error below")

if chunk:
d = dask.array.array([1, 2, 3])
else:
d = np.array([1, 2, 3])
q = pint.Quantity(d, units="m")

actual, _ = groupby_reduce(q, [0, 0, 1], func=func)
expected, _ = groupby_reduce(q.magnitude, [0, 0, 1], func=func)

units = None if func in ["count", "all", "any"] or "arg" in func else getattr(np, func)(q).units
if units is not None:
expected = pint.Quantity(expected, units=units)
assert_equal(expected, actual)


@requires_pint
@pytest.mark.parametrize("chunk", [True, False])
def test_pint_prod_error(chunk):
import pint

if chunk:
d = dask.array.array([1, 2, 3])
else:
d = np.array([1, 2, 3])
q = pint.Quantity(d, units="m")

with pytest.raises(ValueError):
groupby_reduce(q, [0, 0, 1], func="prod")


@pytest.mark.parametrize("sort", [True, False])
def test_expected_index_conversion_passthrough_range_index(sort):
index = pd.RangeIndex(100)