Skip to content

Commit fcebe5e

Browse files
committed
Improvements based on feedback
* Better testing * Clarify comment * Handle other functions as well, like sum, min, max
1 parent a2677ff commit fcebe5e

File tree

3 files changed

+30
-23
lines changed

3 files changed

+30
-23
lines changed

doc/whats-new.rst

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,11 @@ Breaking changes
8585
as positional, all others need to be passed are keyword arguments. This is part of the
8686
refactor to support external backends (:issue:`4309`, :pull:`4989`).
8787
By `Alessandro Amici <https://github.com/alexamici>`_.
88-
- :py:func:`mean` does not change the data if axis is None. This
89-
ensures that Datasets where some variables do not have the averaged
90-
dimensions are not accidentially changed (:issue:`4885`).
91-
By `David Schwörer <https://github.com/dschwoerer>`_
88+
- Functions that are identities for 0d data return the unchanged data
89+
if axis is empty. This ensures that Datasets where some variables do
90+
not have the averaged dimensions are not accidentially changed
91+
(:issue:`4885`, :pull:`5207`). By `David Schwörer
92+
<https://github.com/dschwoerer>`_
9293

9394
Deprecations
9495
~~~~~~~~~~~~

xarray/core/duck_array_ops.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -310,13 +310,21 @@ def _ignore_warnings_if(condition):
310310
yield
311311

312312

313-
def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False):
313+
def _create_nan_agg_method(
314+
name, dask_module=dask_array, coerce_strings=False, invariant_0d=False
315+
):
314316
from . import nanops
315317

316318
def f(values, axis=None, skipna=None, **kwargs):
317319
if kwargs.pop("out", None) is not None:
318320
raise TypeError(f"`out` is not valid for {name}")
319321

322+
# The data is invariant in the case of 0d data, so do not
323+
# change the data (and dtype)
324+
# See https://github.com/pydata/xarray/issues/4885
325+
if invariant_0d and axis == ():
326+
return values
327+
320328
values = asarray(values)
321329

322330
if coerce_strings and values.dtype.kind in "SU":
@@ -354,28 +362,30 @@ def f(values, axis=None, skipna=None, **kwargs):
354362
# See ops.inject_reduce_methods
355363
argmax = _create_nan_agg_method("argmax", coerce_strings=True)
356364
argmin = _create_nan_agg_method("argmin", coerce_strings=True)
357-
max = _create_nan_agg_method("max", coerce_strings=True)
358-
min = _create_nan_agg_method("min", coerce_strings=True)
359-
sum = _create_nan_agg_method("sum")
365+
max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True)
366+
min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True)
367+
sum = _create_nan_agg_method("sum", invariant_0d=True)
360368
sum.numeric_only = True
361369
sum.available_min_count = True
362370
std = _create_nan_agg_method("std")
363371
std.numeric_only = True
364372
var = _create_nan_agg_method("var")
365373
var.numeric_only = True
366-
median = _create_nan_agg_method("median", dask_module=dask_array_compat)
374+
median = _create_nan_agg_method(
375+
"median", dask_module=dask_array_compat, invariant_0d=True
376+
)
367377
median.numeric_only = True
368-
prod = _create_nan_agg_method("prod")
378+
prod = _create_nan_agg_method("prod", invariant_0d=True)
369379
prod.numeric_only = True
370380
prod.available_min_count = True
371-
cumprod_1d = _create_nan_agg_method("cumprod")
381+
cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True)
372382
cumprod_1d.numeric_only = True
373-
cumsum_1d = _create_nan_agg_method("cumsum")
383+
cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True)
374384
cumsum_1d.numeric_only = True
375385
unravel_index = _dask_or_eager_func("unravel_index")
376386

377387

378-
_mean = _create_nan_agg_method("mean")
388+
_mean = _create_nan_agg_method("mean", invariant_0d=True)
379389

380390

381391
def _datetime_nanmin(array):
@@ -537,11 +547,6 @@ def mean(array, axis=None, skipna=None, **kwargs):
537547
dtypes"""
538548
from .common import _contains_cftime_datetimes
539549

540-
# The mean over an empty axis shouldn't change the data
541-
# See https://github.com/pydata/xarray/issues/4885
542-
if axis == tuple():
543-
return array
544-
545550
array = asarray(array)
546551
if array.dtype.kind in "Mm":
547552
offset = _datetime_nanmin(array)

xarray/tests/test_duck_array_ops.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
where,
2727
)
2828
from xarray.core.pycompat import dask_array_type
29-
from xarray.testing import assert_allclose, assert_equal
29+
from xarray.testing import assert_allclose, assert_equal, assert_identical
3030

3131
from . import (
3232
arm_xfail,
@@ -373,14 +373,15 @@ def test_cftime_datetime_mean_dask_error():
373373
da.mean()
374374

375375

376-
def test_mean_dtype():
376+
def test_empty_axis_dtype():
377377
ds = Dataset()
378378
ds["pos"] = [1, 2, 3]
379379
ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
380380
ds["var"] = "pos", [2, 3, 4]
381-
ds2 = ds.mean(dim="time")
382-
assert all(ds2["var"] == ds["var"])
383-
assert ds2["var"].dtype == ds["var"].dtype
381+
assert_identical(ds.mean(dim="time")["var"], ds["var"])
382+
assert_identical(ds.max(dim="time")["var"], ds["var"])
383+
assert_identical(ds.min(dim="time")["var"], ds["var"])
384+
assert_identical(ds.sum(dim="time")["var"], ds["var"])
384385

385386

386387
@pytest.mark.parametrize("dim_num", [1, 2])

0 commit comments

Comments
 (0)