diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d0e2ef3bd59..029231a3753 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -85,6 +85,11 @@ Breaking changes as positional, all others need to be passed are keyword arguments. This is part of the refactor to support external backends (:issue:`4309`, :pull:`4989`). By `Alessandro Amici `_. +- Functions that are identities for 0d data return the unchanged data + if axis is empty. This ensures that Datasets where some variables do + not have the averaged dimensions are not accidentally changed + (:issue:`4885`, :pull:`5207`). By `David Schwörer + `_. Deprecations ~~~~~~~~~~~~ diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 9dcd7906ef7..8947ecd7477 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -310,13 +310,21 @@ def _ignore_warnings_if(condition): yield -def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False): +def _create_nan_agg_method( + name, dask_module=dask_array, coerce_strings=False, invariant_0d=False +): from . 
import nanops def f(values, axis=None, skipna=None, **kwargs): if kwargs.pop("out", None) is not None: raise TypeError(f"`out` is not valid for {name}") + # The data is invariant in the case of 0d data, so do not + # change the data (and dtype) + # See https://github.com/pydata/xarray/issues/4885 + if invariant_0d and axis == (): + return values + values = asarray(values) if coerce_strings and values.dtype.kind in "SU": @@ -354,28 +362,30 @@ def f(values, axis=None, skipna=None, **kwargs): # See ops.inject_reduce_methods argmax = _create_nan_agg_method("argmax", coerce_strings=True) argmin = _create_nan_agg_method("argmin", coerce_strings=True) -max = _create_nan_agg_method("max", coerce_strings=True) -min = _create_nan_agg_method("min", coerce_strings=True) -sum = _create_nan_agg_method("sum") +max = _create_nan_agg_method("max", coerce_strings=True, invariant_0d=True) +min = _create_nan_agg_method("min", coerce_strings=True, invariant_0d=True) +sum = _create_nan_agg_method("sum", invariant_0d=True) sum.numeric_only = True sum.available_min_count = True std = _create_nan_agg_method("std") std.numeric_only = True var = _create_nan_agg_method("var") var.numeric_only = True -median = _create_nan_agg_method("median", dask_module=dask_array_compat) +median = _create_nan_agg_method( + "median", dask_module=dask_array_compat, invariant_0d=True +) median.numeric_only = True -prod = _create_nan_agg_method("prod") +prod = _create_nan_agg_method("prod", invariant_0d=True) prod.numeric_only = True prod.available_min_count = True -cumprod_1d = _create_nan_agg_method("cumprod") +cumprod_1d = _create_nan_agg_method("cumprod", invariant_0d=True) cumprod_1d.numeric_only = True -cumsum_1d = _create_nan_agg_method("cumsum") +cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True) cumsum_1d.numeric_only = True unravel_index = _dask_or_eager_func("unravel_index") -_mean = _create_nan_agg_method("mean") +_mean = _create_nan_agg_method("mean", invariant_0d=True) def 
_datetime_nanmin(array): diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index 1dd26bab6b6..ef81a6108dd 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -26,7 +26,7 @@ where, ) from xarray.core.pycompat import dask_array_type -from xarray.testing import assert_allclose, assert_equal +from xarray.testing import assert_allclose, assert_equal, assert_identical from . import ( arm_xfail, @@ -373,6 +373,17 @@ def test_cftime_datetime_mean_dask_error(): da.mean() +def test_empty_axis_dtype(): + ds = Dataset() + ds["pos"] = [1, 2, 3] + ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + ds["var"] = "pos", [2, 3, 4] + assert_identical(ds.mean(dim="time")["var"], ds["var"]) + assert_identical(ds.max(dim="time")["var"], ds["var"]) + assert_identical(ds.min(dim="time")["var"], ds["var"]) + assert_identical(ds.sum(dim="time")["var"], ds["var"]) + + @pytest.mark.parametrize("dim_num", [1, 2]) @pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) @pytest.mark.parametrize("dask", [False, True])