Skip to content

Commit a2677ff

Browse files
committed
Skip mean over empty axis
Avoids changing the datatype if the data does not have the requested axis.
1 parent b2351cb commit a2677ff

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ Breaking changes
8585
as positional, all others need to be passed are keyword arguments. This is part of the
8686
refactor to support external backends (:issue:`4309`, :pull:`4989`).
8787
By `Alessandro Amici <https://github.com/alexamici>`_.
88+
- :py:func:`mean` does not change the data if axis is None. This
89+
ensures that Datasets where some variables do not have the averaged
90+
dimensions are not accidentially changed (:issue:`4885`).
91+
By `David Schwörer <https://github.com/dschwoerer>`_
8892

8993
Deprecations
9094
~~~~~~~~~~~~

xarray/core/duck_array_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,11 @@ def mean(array, axis=None, skipna=None, **kwargs):
537537
dtypes"""
538538
from .common import _contains_cftime_datetimes
539539

540+
# The mean over an empty axis shouldn't change the data
541+
# See https://github.com/pydata/xarray/issues/4885
542+
if axis == tuple():
543+
return array
544+
540545
array = asarray(array)
541546
if array.dtype.kind in "Mm":
542547
offset = _datetime_nanmin(array)

xarray/tests/test_duck_array_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,16 @@ def test_cftime_datetime_mean_dask_error():
373373
da.mean()
374374

375375

376+
def test_mean_dtype():
377+
ds = Dataset()
378+
ds["pos"] = [1, 2, 3]
379+
ds["data"] = ("pos", "time"), [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
380+
ds["var"] = "pos", [2, 3, 4]
381+
ds2 = ds.mean(dim="time")
382+
assert all(ds2["var"] == ds["var"])
383+
assert ds2["var"].dtype == ds["var"].dtype
384+
385+
376386
@pytest.mark.parametrize("dim_num", [1, 2])
377387
@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_])
378388
@pytest.mark.parametrize("dask", [False, True])

0 commit comments

Comments
 (0)