From d98baf454c348b5cd39c805d00998a766accea27 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 21 Nov 2023 10:24:51 -0800 Subject: [PATCH 01/14] Consolidate `_get_alpha` func (#8465) * Consolidate `_get_alpha` func Am changing this a bit so starting with consolidating it rather than converting twice --- xarray/core/rolling_exp.py | 41 +++++++++++--------------------------- 1 file changed, 12 insertions(+), 29 deletions(-) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index c8160cefef3..04d7dd41966 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -25,51 +25,34 @@ def _get_alpha( span: float | None = None, halflife: float | None = None, alpha: float | None = None, -) -> float: - # pandas defines in terms of com (converting to alpha in the algo) - # so use its function to get a com and then convert to alpha - - com = _get_center_of_mass(com, span, halflife, alpha) - return 1 / (1 + com) - - -def _get_center_of_mass( - comass: float | None, - span: float | None, - halflife: float | None, - alpha: float | None, ) -> float: """ - Vendored from pandas.core.window.common._get_center_of_mass - - See licenses/PANDAS_LICENSE for the function's license + Convert com, span, halflife to alpha. """ - valid_count = count_not_none(comass, span, halflife, alpha) + valid_count = count_not_none(com, span, halflife, alpha) if valid_count > 1: - raise ValueError("comass, span, halflife, and alpha are mutually exclusive") + raise ValueError("com, span, halflife, and alpha are mutually exclusive") - # Convert to center of mass; domain checks ensure 0 < alpha <= 1 - if comass is not None: - if comass < 0: - raise ValueError("comass must satisfy: comass >= 0") + # Convert to alpha + if com is not None: + if com < 0: + raise ValueError("commust satisfy: com>= 0") + return 1 / (com + 1) elif span is not None: if span < 1: raise ValueError("span must satisfy: span >= 1") - comass = (span - 1) / 2.0 + return 2 / (span + 1) elif halflife is not None: if halflife <= 0: raise ValueError("halflife must satisfy: halflife > 0") - decay = 1 - np.exp(np.log(0.5) / halflife) - comass = 1 / decay - 1 + return 1 - np.exp(np.log(0.5) / halflife) elif alpha is not None: - if alpha <= 0 or alpha > 1: + if not 0 < alpha <= 1: raise ValueError("alpha must satisfy: 0 < alpha <= 1") - comass = (1.0 - alpha) / alpha + return alpha else: raise ValueError("Must pass one of comass, span, halflife, or alpha") - return float(comass) - class RollingExp(Generic[T_DataWithCoords]): """ From 7e6eba06d9397f9c70c9252bc9ba4efe6e7a846a Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 21 Nov 2023 10:25:14 -0800 Subject: [PATCH 02/14] Fix `map_blocks` docs' formatting (#8464) --- xarray/core/parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index dd5232023a2..f971556b3f7 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -186,8 +186,9 @@ def map_blocks( Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of the - function. + obj : same as obj + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- From dcf5d743fc8ff66996ff73c08f6893b701ff6e02 Mon Sep 17 00:00:00 2001 From: Illviljan <14371165+Illviljan@users.noreply.github.com> Date: Tue, 21 Nov 2023 20:26:24 +0100 Subject: [PATCH 03/14] Use concise date format when plotting (#8449) * Add concise date format * Update utils.py * Update dataarray_plot.py * Update dataarray_plot.py * Update whats-new.rst * Cleanup * Clarify xfail reason * Update whats-new.rst --- doc/whats-new.rst | 2 ++ xarray/plot/dataarray_plot.py | 35 +++++--------------- xarray/plot/utils.py | 26 ++++++++++++++- xarray/tests/test_plot.py | 62 +++++++++++++++++++++++++++++++---- 4 files changed, 90 insertions(+), 35 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 350cc2e0efa..3698058cfe8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,8 @@ v2023.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Use a concise format when plotting datetime arrays. (:pull:`8449`). + By `Jimmy Westling `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index 61f2014fbc3..6da97a3faf0 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -27,6 +27,7 @@ _rescale_imshow_rgb, _resolve_intervals_1dplot, _resolve_intervals_2dplot, + _set_concise_date, _update_axes, get_axis, label_from_attrs, @@ -525,14 +526,8 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_horizontalalignment("right") + _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) @@ -1087,14 +1082,12 @@ def _add_labels( add_labels: bool | Iterable[bool], darrays: Iterable[DataArray | None], suffixes: Iterable[str], - rotate_labels: Iterable[bool], ax: Axes, ) -> None: """Set x, y, z labels.""" add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels - for axis, add_label, darray, suffix, rotate_label in zip( - ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels - ): + axes: tuple[Literal["x", "y", "z"], ...] = ("x", "y", "z") + for axis, add_label, darray, suffix in zip(axes, add_labels, darrays, suffixes): if darray is None: continue @@ -1103,14 +1096,8 @@ def _add_labels( if label is not None: getattr(ax, f"set_{axis}label")(label) - if rotate_label and np.issubdtype(darray.dtype, np.datetime64): - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - for labels in getattr(ax, f"get_{axis}ticklabels")(): - labels.set_rotation(30) - labels.set_horizontalalignment("right") + if np.issubdtype(darray.dtype, np.datetime64): + _set_concise_date(ax, axis=axis) @overload @@ -1265,7 +1252,7 @@ def scatter( kwargs.update(s=sizeplt.to_numpy().ravel()) plts_or_none = (xplt, yplt, zplt) - _add_labels(add_labels, plts_or_none, ("", "", ""), (True, False, False), ax) + _add_labels(add_labels, plts_or_none, ("", "", ""), ax) xplt_np = None if xplt is None else xplt.to_numpy().ravel() yplt_np = None if yplt is None else yplt.to_numpy().ravel() @@ -1653,14 +1640,8 @@ def newplotfunc( ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim ) - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_horizontalalignment("right") + _set_concise_date(ax, "x") return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 5694acc06e8..903780b1137 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -6,7 +6,7 @@ from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from inspect import getfullargspec -from typing import TYPE_CHECKING, Any, Callable, overload +from typing import TYPE_CHECKING, Any, Callable, Literal, overload import numpy as np import pandas as pd @@ -1827,3 +1827,27 @@ def _guess_coords_to_plot( _assert_valid_xy(darray, dim, k) return coords_to_plot + + +def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None: + """ + Use ConciseDateFormatter which is meant to improve the + strings chosen for the ticklabels, and to minimize the + strings used in those tick labels as much as possible. + + https://matplotlib.org/stable/gallery/ticks/date_concise_formatter.html + + Parameters + ---------- + ax : Axes + Figure axes. + axis : Literal["x", "y", "z"], optional + Which axis to make concise. The default is "x". + """ + import matplotlib.dates as mdates + + locator = mdates.AutoDateLocator() + formatter = mdates.ConciseDateFormatter(locator) + _axis = getattr(ax, f"{axis}axis") + _axis.set_major_locator(locator) + _axis.set_major_formatter(formatter) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 31c23955b02..102d06b0289 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -787,12 +787,17 @@ def test_plot_nans(self) -> None: self.darray[1] = np.nan self.darray.plot.line() - def test_x_ticks_are_rotated_for_time(self) -> None: + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.arange(len(time)), [("t", time)]) a.plot.line() - rotation = plt.gca().get_xticklabels()[0].get_rotation() - assert rotation != 0 + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) def test_xyincrease_false_changes_axes(self) -> None: self.darray.plot.line(xincrease=False, yincrease=False) @@ -1356,12 +1361,17 @@ def test_xyincrease_true_changes_axes(self) -> None: diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9 assert all(abs(x) < 1 for x in diffs) - def test_x_ticks_are_rotated_for_time(self) -> None: + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) - a.plot(x="t") - rotation = plt.gca().get_xticklabels()[0].get_rotation() - assert rotation != 0 + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) def test_plot_nans(self) -> None: x1 = self.darray[:5] @@ -1888,6 +1898,25 @@ def test_interval_breaks_logspace(self) -> None: class TestImshow(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.imshow) + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + @pytest.mark.slow def test_imshow_called(self) -> None: # Having both statements ensures the test works properly @@ -2032,6 +2061,25 @@ class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) subplot_kws = {"projection": "3d"} + @pytest.mark.xfail( + reason=( + "Failing inside matplotlib. Should probably be fixed upstream because " + "other plot functions can handle it. " + "Remove this test when it works, already in Common2dMixin" + ) + ) + def test_dates_are_concise(self) -> None: + import matplotlib.dates as mdates + + time = pd.date_range("2000-01-01", "2000-01-10") + a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) + self.plotfunc(a, x="t") + + ax = plt.gca() + + assert isinstance(ax.xaxis.get_major_locator(), mdates.AutoDateLocator) + assert isinstance(ax.xaxis.get_major_formatter(), mdates.ConciseDateFormatter) + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) From cb14f2fee49210a5d6b18731f9b9b10feee5f909 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 22 Nov 2023 00:01:12 -0800 Subject: [PATCH 04/14] Fix mypy tests (#8476) I was seeing an error in #8475 --- xarray/core/nputils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 316a77ead6a..bd33b7b6d8f 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -31,7 +31,7 @@ _HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0") except ImportError: # use numpy methods instead - numbagg = np + numbagg = np # type: ignore _HAS_NUMBAGG = False From 41b1b8cede2b151c797e5679a6260d02ed71fe26 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 22 Nov 2023 08:45:02 -0800 Subject: [PATCH 05/14] Allow `rank` to run on dask arrays (#8475) --- xarray/core/variable.py | 27 ++++++++++++--------------- xarray/tests/test_variable.py | 20 ++++++++++++++++---- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index db109a40454..c2133d55aeb 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2063,6 +2063,7 @@ def rank(self, dim, pct=False): -------- Dataset.rank, DataArray.rank """ + # This could / should arguably be implemented at the DataArray & Dataset level if not OPTIONS["use_bottleneck"]: raise RuntimeError( "rank requires bottleneck to be enabled." @@ -2071,24 +2072,20 @@ def rank(self, dim, pct=False): import bottleneck as bn - data = self.data - - if is_duck_dask_array(data): - raise TypeError( - "rank does not work for arrays stored as dask " - "arrays. Load the data via .compute() or .load() " - "prior to calling this method." - ) - elif not isinstance(data, np.ndarray): - raise TypeError(f"rank is not implemented for {type(data)} objects.") - - axis = self.get_axis_num(dim) func = bn.nanrankdata if self.dtype.kind == "f" else bn.rankdata - ranked = func(data, axis=axis) + ranked = xr.apply_ufunc( + func, + self, + input_core_dims=[[dim]], + output_core_dims=[[dim]], + dask="parallelized", + kwargs=dict(axis=-1), + ).transpose(*self.dims) + if pct: - count = np.sum(~np.isnan(data), axis=axis, keepdims=True) + count = self.notnull().sum(dim) ranked /= count - return Variable(self.dims, ranked) + return ranked def rolling_window( self, dim, window, window_dim, center=False, fill_value=dtypes.NA diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 8a73e435977..d91cf85e4eb 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1878,9 +1878,20 @@ def test_quantile_out_of_bounds(self, q): @requires_dask @requires_bottleneck - def test_rank_dask_raises(self): - v = Variable(["x"], [3.0, 1.0, np.nan, 2.0, 4.0]).chunk(2) - with pytest.raises(TypeError, match=r"arrays stored as dask"): + def test_rank_dask(self): + # Instead of a single test here, we could parameterize the other tests for both + # arrays. But this is sufficient. + v = Variable( + ["x", "y"], [[30.0, 1.0, np.nan, 20.0, 4.0], [30.0, 1.0, np.nan, 20.0, 4.0]] + ).chunk(x=1) + expected = Variable( + ["x", "y"], [[4.0, 1.0, np.nan, 3.0, 2.0], [4.0, 1.0, np.nan, 3.0, 2.0]] + ) + assert_equal(v.rank("y").compute(), expected) + + with pytest.raises( + ValueError, match=r" with dask='parallelized' consists of multiple chunks" + ): v.rank("x") def test_rank_use_bottleneck(self): @@ -1912,7 +1923,8 @@ def test_rank(self): v_expect = Variable(["x"], [0.75, 0.25, np.nan, 0.5, 1.0]) assert_equal(v.rank("x", pct=True), v_expect) # invalid dim - with pytest.raises(ValueError, match=r"not found"): + with pytest.raises(ValueError): + # apply_ufunc error message isn't great here — `ValueError: tuple.index(x): x not in tuple` v.rank("y") def test_big_endian_reduce(self): From 398d8e6c393efe72978ed07ab2d37973771db7b3 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 22 Nov 2023 10:45:22 -0800 Subject: [PATCH 06/14] Add whatsnew for #8475 (#8478) Sorry, forgot in the original PR --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3698058cfe8..76548fe95c5 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -26,6 +26,11 @@ New Features - Use a concise format when plotting datetime arrays. (:pull:`8449`). By `Jimmy Westling `_. + +- :py:meth:`~xarray.DataArray.rank` now operates on dask-backed arrays, assuming + the core dim has exactly one chunk. (:pull:`8475`). + By `Maximilian Roos `_. + Breaking changes ~~~~~~~~~~~~~~~~ From 71c2f6199f0aa569409d791dff6ee5e06f8c8665 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Fri, 24 Nov 2023 10:49:37 -0800 Subject: [PATCH 07/14] Improve "variable not found" error message (#8474) * Improve missing variable error message * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 3 +++ xarray/core/dataset.py | 10 +++++++++- xarray/core/formatting.py | 15 ++++++++++++++- xarray/tests/test_error_messages.py | 17 +++++++++++++++++ 4 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 xarray/tests/test_error_messages.py diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 76548fe95c5..a8302715317 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -46,6 +46,9 @@ Bug fixes Documentation ~~~~~~~~~~~~~ +- Improved error message when attempting to get a variable which doesn't exist from a Dataset. + (:pull:`8474`) + By `Maximilian Roos `_. Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c8e7564d3ca..5d2d24d6723 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1539,10 +1539,18 @@ def __getitem__( Indexing with a list of names will return a new ``Dataset`` object. """ + from xarray.core.formatting import shorten_list_repr + if utils.is_dict_like(key): return self.isel(**key) if utils.hashable(key): - return self._construct_dataarray(key) + try: + return self._construct_dataarray(key) + except KeyError as e: + raise KeyError( + f"No variable named {key!r}. Variables on the dataset include {shorten_list_repr(list(self.variables.keys()), max_items=10)}" + ) from e + if utils.iterable_of_hashable(key): return self._copy_listed(key) raise ValueError(f"Unsupported key-type {type(key)}") diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index a915e9acbf3..ea0e6275fb6 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -6,7 +6,7 @@ import functools import math from collections import defaultdict -from collections.abc import Collection, Hashable +from collections.abc import Collection, Hashable, Sequence from datetime import datetime, timedelta from itertools import chain, zip_longest from reprlib import recursive_repr @@ -937,3 +937,16 @@ def diff_dataset_repr(a, b, compat): summary.append(diff_attrs_repr(a.attrs, b.attrs, compat)) return "\n".join(summary) + + +def shorten_list_repr(items: Sequence, max_items: int) -> str: + if len(items) <= max_items: + return repr(items) + else: + first_half = repr(items[: max_items // 2])[ + 1:-1 + ] # Convert to string and remove brackets + second_half = repr(items[-max_items // 2 :])[ + 1:-1 + ] # Convert to string and remove brackets + return f"[{first_half}, ..., {second_half}]" diff --git a/xarray/tests/test_error_messages.py b/xarray/tests/test_error_messages.py new file mode 100644 index 00000000000..b5840aafdfa --- /dev/null +++ b/xarray/tests/test_error_messages.py @@ -0,0 +1,17 @@ +""" +This new file is intended to test the quality & friendliness of error messages that are +raised by xarray. It's currently separate from the standard tests, which are more +focused on the functions working (though we could consider integrating them.). +""" + +import pytest + + +def test_no_var_in_dataset(ds): + with pytest.raises( + KeyError, + match=( + r"No variable named 'foo'. Variables on the dataset include \['z1', 'z2', 'x', 'time', 'c', 'y'\]" + ), + ): + ds["foo"] From dc66f0d2b34754fb2a8d29d8eb635a5b143755ad Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Fri, 24 Nov 2023 23:56:18 +0100 Subject: [PATCH 08/14] Fix bug for categorical pandas index with categories with EA dtype (#8481) * Fix bug for categorical pandas index with categories with EA dtype * Add whatsnew * Update xarray/tests/test_dataset.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --------- Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- doc/whats-new.rst | 2 ++ xarray/core/utils.py | 2 ++ xarray/tests/test_dataset.py | 11 +++++++++++ 3 files changed, 15 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index a8302715317..d92f3239f60 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -42,6 +42,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ad86b2c7fec..9ba4a43f6d9 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -114,6 +114,8 @@ def get_valid_numpy_dtype(array: np.ndarray | pd.Index): elif hasattr(array, "categories"): # category isn't a real numpy dtype dtype = array.categories.dtype + if not is_valid_numpy_dtype(dtype): + dtype = np.dtype("O") elif not is_valid_numpy_dtype(array.dtype): dtype = np.dtype("O") else: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index ff7703a1cf5..a53d81e36af 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4697,6 +4697,17 @@ def test_from_dataframe_categorical(self) -> None: assert len(ds["i1"]) == 2 assert len(ds["i2"]) == 2 + def test_from_dataframe_categorical_string_categories(self) -> None: + cat = pd.CategoricalIndex( + pd.Categorical.from_codes( + np.array([1, 1, 0, 2]), + categories=pd.Index(["foo", "bar", "baz"], dtype="string"), + ) + ) + ser = pd.Series(1, index=cat) + ds = ser.to_xarray() + assert ds.coords.dtypes["index"] == np.dtype("O") + @requires_sparse def test_from_dataframe_sparse(self) -> None: import sparse From a6837909ded8c3fe2d9ff8655b04d18f46d77ed5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 25 Nov 2023 13:06:08 -0800 Subject: [PATCH 09/14] Use numbagg for `ffill` by default (#8389) * Use `numbagg` for `ffill` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use duck_array_ops for numbagg version, test import is lazy * Update xarray/core/duck_array_ops.py Co-authored-by: Deepak Cherian * Update xarray/core/nputils.py Co-authored-by: Deepak Cherian * Update xarray/core/rolling_exp.py Co-authored-by: Deepak Cherian * Update xarray/core/nputils.py Co-authored-by: Deepak Cherian --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian --- doc/whats-new.rst | 4 +++ xarray/backends/zarr.py | 4 +-- xarray/core/dask_array_ops.py | 5 ++-- xarray/core/duck_array_ops.py | 41 +++++++++++++++++++++++++--- xarray/core/missing.py | 12 +-------- xarray/core/nputils.py | 34 +++++++++++------------- xarray/core/pycompat.py | 5 +++- xarray/core/rolling_exp.py | 50 ++++++++++++++++++----------------- xarray/tests/__init__.py | 7 ++++- xarray/tests/test_missing.py | 24 ++++++++++++----- xarray/tests/test_plugins.py | 29 ++++++++++---------- 11 files changed, 132 insertions(+), 83 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d92f3239f60..b2efe650e28 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -55,6 +55,10 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- :py:meth:`DataArray.bfill` & :py:meth:`DataArray.ffill` now use numbagg by + default, which is up to 5x faster where parallelization is possible. (:pull:`8339`) + By `Maximilian Roos `_. + .. _whats-new.2023.11.0: v2023.11.0 (Nov 16, 2023) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 6632e40cf6f..f0eece3bb61 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -177,8 +177,8 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name, safe_chunks): # DESIGN CHOICE: do not allow multiple dask chunks on a single zarr chunk # this avoids the need to get involved in zarr synchronization / locking # From zarr docs: - # "If each worker in a parallel computation is writing to a separate - # region of the array, and if region boundaries are perfectly aligned + # "If each worker in a parallel computation is writing to a + # separate region of the array, and if region boundaries are perfectly aligned # with chunk boundaries, then no synchronization is required." # TODO: incorporate synchronizer to allow writes from multiple dask # threads diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index d2d3e4a6d1c..98ff9002856 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -59,10 +59,11 @@ def push(array, n, axis): """ Dask-aware bottleneck.push """ - import bottleneck import dask.array as da import numpy as np + from xarray.core.duck_array_ops import _push + def _fill_with_last_one(a, b): # cumreduction apply the push func over all the blocks first so, the only missing part is filling # the missing values using the last data of the previous chunk @@ -85,7 +86,7 @@ def _fill_with_last_one(a, b): # The method parameter makes that the tests for python 3.7 fails. return da.reductions.cumreduction( - func=bottleneck.push, + func=_push, binop=_fill_with_last_one, ident=np.nan, x=array, diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b9f7db9737f..7f2b2ed85ee 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -31,8 +31,10 @@ from numpy import concatenate as _concatenate from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] from numpy.lib.stride_tricks import sliding_window_view # noqa +from packaging.version import Version -from xarray.core import dask_array_ops, dtypes, nputils +from xarray.core import dask_array_ops, dtypes, nputils, pycompat +from xarray.core.options import OPTIONS from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.pycompat import array_type, is_duck_dask_array from xarray.core.utils import is_duck_array, module_available @@ -688,13 +690,44 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) -def push(array, n, axis): - from bottleneck import push +def _push(array, n: int | None = None, axis: int = -1): + """ + Use either bottleneck or numbagg depending on options & what's available + """ + + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) + if OPTIONS["use_numbagg"] and module_available("numbagg"): + import numbagg + + if pycompat.mod_version("numbagg") < Version("0.6.2"): + warnings.warn( + f"numbagg >= 0.6.2 is required for bfill & ffill; {pycompat.mod_version('numbagg')} is installed. We'll attempt with bottleneck instead." + ) + else: + return numbagg.ffill(array, limit=n, axis=axis) + + # work around for bottleneck 178 + limit = n if n is not None else array.shape[axis] + + import bottleneck as bn + + return bn.push(array, limit, axis) + +def push(array, n, axis): + if not OPTIONS["use_bottleneck"] and not OPTIONS["use_numbagg"]: + raise RuntimeError( + "ffill & bfill requires bottleneck or numbagg to be enabled." + " Call `xr.set_options(use_bottleneck=True)` or `xr.set_options(use_numbagg=True)` to enable one." + ) if is_duck_dask_array(array): return dask_array_ops.push(array, n, axis) else: - return push(array, n, axis) + return _push(array, n, axis) def _first_last_wrapper(array, *, axis, op, keepdims): diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 90a9dd2e76c..b55fd6049a6 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -14,7 +14,7 @@ from xarray.core.common import _contains_datetime_like_objects, ones_like from xarray.core.computation import apply_ufunc from xarray.core.duck_array_ops import datetime_to_numeric, push, timedelta_to_numeric -from xarray.core.options import OPTIONS, _get_keep_attrs +from xarray.core.options import _get_keep_attrs from xarray.core.parallelcompat import get_chunked_array_type, is_chunked_array from xarray.core.types import Interp1dOptions, InterpOptions from xarray.core.utils import OrderedSet, is_scalar @@ -413,11 +413,6 @@ def _bfill(arr, n=None, axis=-1): def ffill(arr, dim=None, limit=None): """forward fill missing values""" - if not OPTIONS["use_bottleneck"]: - raise RuntimeError( - "ffill requires bottleneck to be enabled." - " Call `xr.set_options(use_bottleneck=True)` to enable it." - ) axis = arr.get_axis_num(dim) @@ -436,11 +431,6 @@ def ffill(arr, dim=None, limit=None): def bfill(arr, dim=None, limit=None): """backfill missing values""" - if not OPTIONS["use_bottleneck"]: - raise RuntimeError( - "bfill requires bottleneck to be enabled." - " Call `xr.set_options(use_bottleneck=True)` to enable it." - ) axis = arr.get_axis_num(dim) diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index bd33b7b6d8f..96e5548b9b4 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -1,12 +1,16 @@ from __future__ import annotations import warnings +from typing import Callable import numpy as np import pandas as pd from numpy.core.multiarray import normalize_axis_index # type: ignore[attr-defined] from packaging.version import Version +from xarray.core import pycompat +from xarray.core.utils import module_available + # remove once numpy 2.0 is the oldest supported version try: from numpy.exceptions import RankWarning # type: ignore[attr-defined,unused-ignore] @@ -25,15 +29,6 @@ bn = np _BOTTLENECK_AVAILABLE = False -try: - import numbagg - - _HAS_NUMBAGG = Version(numbagg.__version__) >= Version("0.5.0") -except ImportError: - # use numpy methods instead - numbagg = np # type: ignore - _HAS_NUMBAGG = False - def _select_along_axis(values, idx, axis): other_ind = np.ix_(*[np.arange(s) for s in idx.shape]) @@ -171,17 +166,16 @@ def __setitem__(self, key, value): self._array[key] = np.moveaxis(value, vindex_positions, mixed_positions) -def _create_method(name, npmodule=np): +def _create_method(name, npmodule=np) -> Callable: def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype", None) bn_func = getattr(bn, name, None) - nba_func = getattr(numbagg, name, None) if ( - _HAS_NUMBAGG + module_available("numbagg") + and pycompat.mod_version("numbagg") >= Version("0.5.0") and OPTIONS["use_numbagg"] and isinstance(values, np.ndarray) - and nba_func is not None # numbagg uses ddof=1 only, but numpy uses ddof=0 by default and (("var" in name or "std" in name) and kwargs.get("ddof", 0) == 1) # TODO: bool? @@ -189,11 +183,15 @@ def f(values, axis=None, **kwargs): # and values.dtype.isnative and (dtype is None or np.dtype(dtype) == values.dtype) ): - # numbagg does not take care dtype, ddof - kwargs.pop("dtype", None) - kwargs.pop("ddof", None) - result = nba_func(values, axis=axis, **kwargs) - elif ( + import numbagg + + nba_func = getattr(numbagg, name, None) + if nba_func is not None: + # numbagg does not take care dtype, ddof + kwargs.pop("dtype", None) + kwargs.pop("ddof", None) + return nba_func(values, axis=axis, **kwargs) + if ( _BOTTLENECK_AVAILABLE and OPTIONS["use_bottleneck"] and isinstance(values, np.ndarray) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index bc8b61164f1..32ef408f7cc 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -12,7 +12,7 @@ integer_types = (int, np.integer) if TYPE_CHECKING: - ModType = Literal["dask", "pint", "cupy", "sparse", "cubed"] + ModType = Literal["dask", "pint", "cupy", "sparse", "cubed", "numbagg"] DuckArrayTypes = tuple[type[Any], ...] # TODO: improve this? maybe Generic @@ -47,6 +47,9 @@ def __init__(self, mod: ModType) -> None: duck_array_type = (duck_array_module.SparseArray,) elif mod == "cubed": duck_array_type = (duck_array_module.Array,) + # Not a duck array module, but using this system regardless, to get lazy imports + elif mod == "numbagg": + duck_array_type = () else: raise NotImplementedError diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 04d7dd41966..1e4b805208f 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -6,18 +6,12 @@ import numpy as np from packaging.version import Version +from xarray.core import pycompat from xarray.core.computation import apply_ufunc from xarray.core.options import _get_keep_attrs from xarray.core.pdcompat import count_not_none from xarray.core.types import T_DataWithCoords - -try: - import numbagg - from numbagg import move_exp_nanmean, move_exp_nansum - - _NUMBAGG_VERSION: Version | None = Version(numbagg.__version__) -except ImportError: - _NUMBAGG_VERSION = None +from xarray.core.utils import module_available def _get_alpha( @@ -83,17 +77,17 @@ def __init__( window_type: str = "span", min_weight: float = 0.0, ): - if _NUMBAGG_VERSION is None: + if not module_available("numbagg"): raise ImportError( "numbagg >= 0.2.1 is required for rolling_exp but currently numbagg is not installed" ) - elif _NUMBAGG_VERSION < Version("0.2.1"): + elif pycompat.mod_version("numbagg") < Version("0.2.1"): raise ImportError( - f"numbagg >= 0.2.1 is required for rolling_exp but currently version {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.2.1 is required for rolling_exp but currently version {pycompat.mod_version('numbagg')} is installed" ) - elif _NUMBAGG_VERSION < Version("0.3.1") and min_weight > 0: + elif pycompat.mod_version("numbagg") < Version("0.3.1") and min_weight > 0: raise ImportError( - f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.3.1 is required for `min_weight > 0` within `.rolling_exp` but currently version {pycompat.mod_version('numbagg')} is installed" ) self.obj: T_DataWithCoords = obj @@ -127,13 +121,15 @@ def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: Dimensions without coordinates: x """ + import numbagg + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) dim_order = self.obj.dims return apply_ufunc( - move_exp_nanmean, + numbagg.move_exp_nanmean, self.obj, input_core_dims=[[self.dim]], kwargs=self.kwargs, @@ -163,13 +159,15 @@ def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: Dimensions without coordinates: x """ + import numbagg + if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) dim_order = self.obj.dims return apply_ufunc( - move_exp_nansum, + numbagg.move_exp_nansum, self.obj, input_core_dims=[[self.dim]], kwargs=self.kwargs, @@ -194,10 +192,12 @@ def std(self) -> T_DataWithCoords: Dimensions without coordinates: x """ - if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" ) + import numbagg + dim_order = self.obj.dims return apply_ufunc( @@ -225,12 +225,12 @@ def var(self) -> T_DataWithCoords: array([ nan, 0. , 0.46153846, 0.18461538, 0.06446281]) Dimensions without coordinates: x """ - - if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims + import numbagg return apply_ufunc( numbagg.move_exp_nanvar, @@ -258,11 +258,12 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: Dimensions without coordinates: x """ - if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims + import numbagg return apply_ufunc( numbagg.move_exp_nancov, @@ -291,11 +292,12 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: Dimensions without coordinates: x """ - if _NUMBAGG_VERSION is None or _NUMBAGG_VERSION < Version("0.4.0"): + if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {_NUMBAGG_VERSION} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims + import numbagg return apply_ufunc( numbagg.move_exp_nancorr, diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 8b5cf456bcb..f7f8f823d78 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -53,7 +53,8 @@ def _importorskip( mod = importlib.import_module(modname) has = True if minversion is not None: - if Version(mod.__version__) < Version(minversion): + v = getattr(mod, "__version__", "999") + if Version(v) < Version(minversion): raise ImportError("Minimum version not satisfied") except ImportError: has = False @@ -96,6 +97,10 @@ def _importorskip( requires_scipy_or_netCDF4 = pytest.mark.skipif( not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" ) +has_numbagg_or_bottleneck = has_numbagg or has_bottleneck +requires_numbagg_or_bottleneck = pytest.mark.skipif( + not has_scipy_or_netCDF4, reason="requires scipy or netCDF4" +) # _importorskip does not work for development versions has_pandas_version_two = Version(pd.__version__).major >= 2 requires_pandas_version_two = pytest.mark.skipif( diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py index e318bf01a7e..20a54c3ed53 100644 --- a/xarray/tests/test_missing.py +++ b/xarray/tests/test_missing.py @@ -24,6 +24,8 @@ requires_bottleneck, requires_cftime, requires_dask, + requires_numbagg, + requires_numbagg_or_bottleneck, requires_scipy, ) @@ -407,7 +409,7 @@ def test_interpolate_dask_expected_dtype(dtype, method): assert da.dtype == da.compute().dtype -@requires_bottleneck +@requires_numbagg_or_bottleneck def test_ffill(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") expected = xr.DataArray(np.array([4, 5, 5], dtype=np.float64), dims="x") @@ -415,9 +417,9 @@ def test_ffill(): assert_equal(actual, expected) -def test_ffill_use_bottleneck(): +def test_ffill_use_bottleneck_numbagg(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.ffill("x") @@ -426,14 +428,24 @@ def test_ffill_use_bottleneck(): def test_ffill_use_bottleneck_dask(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.ffill("x") +@requires_numbagg +@requires_dask +def test_ffill_use_numbagg_dask(): + with xr.set_options(use_bottleneck=False): + da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") + da = da.chunk(x=-1) + # Succeeds with a single chunk: + _ = da.ffill("x").compute() + + def test_bfill_use_bottleneck(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.bfill("x") @@ -442,7 +454,7 @@ def test_bfill_use_bottleneck(): def test_bfill_use_bottleneck_dask(): da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x") da = da.chunk({"x": 1}) - with xr.set_options(use_bottleneck=False): + with xr.set_options(use_bottleneck=False, use_numbagg=False): with pytest.raises(RuntimeError): da.bfill("x") diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py index 1af255d30bb..b518c973d3a 100644 --- a/xarray/tests/test_plugins.py +++ b/xarray/tests/test_plugins.py @@ -218,28 +218,29 @@ def test_lazy_import() -> None: When importing xarray these should not be imported as well. Only when running code for the first time that requires them. """ - blacklisted = [ + deny_list = [ + "cubed", + "cupy", + # "dask", # TODO: backends.locks is not lazy yet :( + "dask.array", + "dask.distributed", + "flox", "h5netcdf", + "matplotlib", + "nc_time_axis", "netCDF4", - "pydap", "Nio", + "numbagg", + "pint", + "pydap", "scipy", - "zarr", - "matplotlib", - "nc_time_axis", - "flox", - # "dask", # TODO: backends.locks is not lazy yet :( - "dask.array", - "dask.distributed", "sparse", - "cupy", - "pint", - "cubed", + "zarr", ] # ensure that none of the above modules has been imported before modules_backup = {} for pkg in list(sys.modules.keys()): - for mod in blacklisted + ["xarray"]: + for mod in deny_list + ["xarray"]: if pkg.startswith(mod): modules_backup[pkg] = sys.modules[pkg] del sys.modules[pkg] @@ -255,7 +256,7 @@ def test_lazy_import() -> None: # lazy loaded are loaded when importing xarray is_imported = set() for pkg in sys.modules: - for mod in blacklisted: + for mod in deny_list: if pkg.startswith(mod): is_imported.add(mod) break From 633e66a64e364c42b59294bd5ce60c6627a18d25 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Sat, 25 Nov 2023 13:55:19 -0800 Subject: [PATCH 10/14] Refine rolling_exp error messages (#8485) (Sorry, copy & pasted too liberally!) --- xarray/core/rolling_exp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 1e4b805208f..144e26a86b2 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -227,7 +227,7 @@ def var(self) -> T_DataWithCoords: """ if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().var(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims import numbagg @@ -260,7 +260,7 @@ def cov(self, other: T_DataWithCoords) -> T_DataWithCoords: if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().cov(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims import numbagg @@ -294,7 +294,7 @@ def corr(self, other: T_DataWithCoords) -> T_DataWithCoords: if pycompat.mod_version("numbagg") < Version("0.4.0"): raise ImportError( - f"numbagg >= 0.4.0 is required for rolling_exp().std(), currently {pycompat.mod_version('numbagg')} is installed" + f"numbagg >= 0.4.0 is required for rolling_exp().corr(), currently {pycompat.mod_version('numbagg')} is installed" ) dim_order = self.obj.dims import numbagg From d54c461c0af9f7d0945862ebc9dec1a3b0eacca6 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Mon, 27 Nov 2023 12:56:56 -0800 Subject: [PATCH 11/14] Fix Zarr region transpose (#8484) * Fix Zarr region transpose This wasn't working on an unregion-ed write; I think because `new_var` was being lost. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/whats-new.rst | 2 ++ xarray/backends/zarr.py | 8 +++----- xarray/tests/test_backends.py | 9 +++++++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b2efe650e28..71cbc1a08ee 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -43,6 +43,8 @@ Bug fixes ~~~~~~~~~ - Fix dtype inference for ``pd.CategoricalIndex`` when categories are backed by a ``pd.ExtensionDtype`` (:pull:`8481`) +- Fix writing a variable that requires transposing when not writing to a region (:pull:`8484`) + By `Maximilian Roos `_. Documentation diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index f0eece3bb61..7f1af10b45a 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -624,12 +624,10 @@ def store( variables_encoded.update(vars_with_encoding) for var_name in existing_variable_names: - new_var = variables_encoded[var_name] - existing_var = existing_vars[var_name] - new_var = _validate_and_transpose_existing_dims( + variables_encoded[var_name] = _validate_and_transpose_existing_dims( var_name, - new_var, - existing_var, + variables_encoded[var_name], + existing_vars[var_name], self._write_region, self._append_dim, ) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 85248b5c40a..0704dd835c0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -5423,7 +5423,7 @@ def test_zarr_region_append(self, tmp_path): @requires_zarr -def test_zarr_region_transpose(tmp_path): +def test_zarr_region(tmp_path): x = np.arange(0, 50, 10) y = np.arange(0, 20, 2) data = np.ones((5, 10)) @@ -5438,7 +5438,12 @@ def test_zarr_region_transpose(tmp_path): ) ds.to_zarr(tmp_path / "test.zarr") - ds_region = 1 + ds.isel(x=[0], y=[0]).transpose() + ds_transposed = ds.transpose("y", "x") + + ds_region = 1 + ds_transposed.isel(x=[0], y=[0]) ds_region.to_zarr( tmp_path / "test.zarr", region={"x": slice(0, 1), "y": slice(0, 1)} ) + + # Write without region + ds_transposed.to_zarr(tmp_path / "test.zarr", mode="r+") From d3a15274b41810efc656bc4aeec0e1955cf2be32 Mon Sep 17 00:00:00 2001 From: Max Jones <14077947+maxrjones@users.noreply.github.com> Date: Tue, 28 Nov 2023 01:29:43 -0500 Subject: [PATCH 12/14] Reduce redundancy between namedarray and variable tests (#8405) Co-authored-by: Deepak Cherian Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> --- xarray/tests/test_namedarray.py | 788 +++++++++++++++++--------------- xarray/tests/test_variable.py | 44 +- 2 files changed, 420 insertions(+), 412 deletions(-) diff --git a/xarray/tests/test_namedarray.py b/xarray/tests/test_namedarray.py index e0141e12755..fcdf063d106 100644 --- a/xarray/tests/test_namedarray.py +++ b/xarray/tests/test_namedarray.py @@ -2,6 +2,7 @@ import copy import warnings +from abc import abstractmethod from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Generic, cast, overload @@ -57,387 +58,420 @@ def __array_namespace__(self) -> ModuleType: return np -@pytest.fixture -def random_inputs() -> np.ndarray[Any, np.dtype[np.float32]]: - return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) - - -def test_namedarray_init() -> None: - dtype = np.dtype(np.int8) - expected = np.array([1, 2], dtype=dtype) - actual: NamedArray[Any, np.dtype[np.int8]] - actual = NamedArray(("x",), expected) - assert np.array_equal(np.asarray(actual.data), expected) - - with pytest.raises(AttributeError): - expected2 = [1, 2] - actual2: NamedArray[Any, Any] - actual2 = NamedArray(("x",), expected2) # type: ignore[arg-type] - assert np.array_equal(np.asarray(actual2.data), expected2) - - -@pytest.mark.parametrize( - "dims, data, expected, raise_error", - [ - (("x",), [1, 2, 3], np.array([1, 2, 3]), False), - ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False), - ((), 2, np.array(2), False), - # Fail: - (("x",), NamedArray("time", np.array([1, 2, 3])), np.array([1, 2, 3]), True), - ], -) -def test_from_array( - dims: _DimsLike, - data: ArrayLike, - expected: np.ndarray[Any, Any], - raise_error: bool, -) -> None: - actual: NamedArray[Any, Any] - if raise_error: - with pytest.raises(TypeError, match="already a Named array"): - actual = from_array(dims, data) - - # Named arrays are not allowed: - from_array(actual) # type: ignore[call-overload] - else: - actual = from_array(dims, data) - +class NamedArraySubclassobjects: + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> Any: + """Fixture that needs to be overridden""" + raise NotImplementedError + + @abstractmethod + def cls(self, *args: Any, **kwargs: Any) -> Any: + """Method that needs to be overridden""" + raise NotImplementedError + + @pytest.fixture + def data(self) -> np.ndarray[Any, np.dtype[Any]]: + return 0.5 * np.arange(10).reshape(2, 5) + + @pytest.fixture + def random_inputs(self) -> np.ndarray[Any, np.dtype[np.float32]]: + return np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + + def test_properties(self, target: Any, data: Any) -> None: + assert target.dims == ("x", "y") + assert np.array_equal(target.data, data) + assert target.dtype == float + assert target.shape == (2, 5) + assert target.ndim == 2 + assert target.sizes == {"x": 2, "y": 5} + assert target.size == 10 + assert target.nbytes == 80 + assert len(target) == 2 + + def test_attrs(self, target: Any) -> None: + assert target.attrs == {} + attrs = {"foo": "bar"} + target.attrs = attrs + assert target.attrs == attrs + assert isinstance(target.attrs, dict) + target.attrs["foo"] = "baz" + assert target.attrs["foo"] == "baz" + + @pytest.mark.parametrize( + "expected", [np.array([1, 2], dtype=np.dtype(np.int8)), [1, 2]] + ) + def test_init(self, expected: Any) -> None: + actual = self.cls(("x",), expected) assert np.array_equal(np.asarray(actual.data), expected) + actual = self.cls(("x",), expected) + assert np.array_equal(np.asarray(actual.data), expected) -def test_from_array_with_masked_array() -> None: - masked_array: np.ndarray[Any, np.dtype[np.generic]] - masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] - with pytest.raises(NotImplementedError): - from_array(("x",), masked_array) - - -def test_from_array_with_0d_object() -> None: - data = np.empty((), dtype=object) - data[()] = (10, 12, 12) - narr = from_array((), data) - np.array_equal(np.asarray(narr.data), data) - - -# TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api -# and remove this test. -def test_from_array_with_explicitly_indexed( - random_inputs: np.ndarray[Any, Any] -) -> None: - array: CustomArray[Any, Any] - array = CustomArray(random_inputs) - output: NamedArray[Any, Any] - output = from_array(("x", "y", "z"), array) - assert isinstance(output.data, np.ndarray) - - array2: CustomArrayIndexable[Any, Any] - array2 = CustomArrayIndexable(random_inputs) - output2: NamedArray[Any, Any] - output2 = from_array(("x", "y", "z"), array2) - assert isinstance(output2.data, CustomArrayIndexable) - - -def test_properties() -> None: - data = 0.5 * np.arange(10).reshape(2, 5) - named_array: NamedArray[Any, Any] - named_array = NamedArray(["x", "y"], data, {"key": "value"}) - assert named_array.dims == ("x", "y") - assert np.array_equal(np.asarray(named_array.data), data) - assert named_array.attrs == {"key": "value"} - assert named_array.ndim == 2 - assert named_array.sizes == {"x": 2, "y": 5} - assert named_array.size == 10 - assert named_array.nbytes == 80 - assert len(named_array) == 2 - - -def test_attrs() -> None: - named_array: NamedArray[Any, Any] - named_array = NamedArray(["x", "y"], np.arange(10).reshape(2, 5)) - assert named_array.attrs == {} - named_array.attrs["key"] = "value" - assert named_array.attrs == {"key": "value"} - named_array.attrs = {"key": "value2"} - assert named_array.attrs == {"key": "value2"} - - -def test_data(random_inputs: np.ndarray[Any, Any]) -> None: - named_array: NamedArray[Any, Any] - named_array = NamedArray(["x", "y", "z"], random_inputs) - assert np.array_equal(np.asarray(named_array.data), random_inputs) - with pytest.raises(ValueError): - named_array.data = np.random.random((3, 4)).astype(np.float64) - - -def test_real_and_imag() -> None: - expected_real: np.ndarray[Any, np.dtype[np.float64]] - expected_real = np.arange(3, dtype=np.float64) - - expected_imag: np.ndarray[Any, np.dtype[np.float64]] - expected_imag = -np.arange(3, dtype=np.float64) - - arr: np.ndarray[Any, np.dtype[np.complex128]] - arr = expected_real + 1j * expected_imag - - named_array: NamedArray[Any, np.dtype[np.complex128]] - named_array = NamedArray(["x"], arr) - - actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data - assert np.array_equal(np.asarray(actual_real), expected_real) - assert actual_real.dtype == expected_real.dtype - - actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data - assert np.array_equal(np.asarray(actual_imag), expected_imag) - assert actual_imag.dtype == expected_imag.dtype - - -# Additional tests as per your original class-based code -@pytest.mark.parametrize( - "data, dtype", - [ - ("foo", np.dtype("U3")), - (b"foo", np.dtype("S3")), - ], -) -def test_0d_string(data: Any, dtype: DTypeLike) -> None: - named_array: NamedArray[Any, Any] - named_array = from_array([], data) - assert named_array.data == data - assert named_array.dims == () - assert named_array.sizes == {} - assert named_array.attrs == {} - assert named_array.ndim == 0 - assert named_array.size == 1 - assert named_array.dtype == dtype - - -def test_0d_object() -> None: - named_array: NamedArray[Any, Any] - named_array = from_array([], (10, 12, 12)) - expected_data = np.empty((), dtype=object) - expected_data[()] = (10, 12, 12) - assert np.array_equal(np.asarray(named_array.data), expected_data) - - assert named_array.dims == () - assert named_array.sizes == {} - assert named_array.attrs == {} - assert named_array.ndim == 0 - assert named_array.size == 1 - assert named_array.dtype == np.dtype("O") - - -def test_0d_datetime() -> None: - named_array: NamedArray[Any, Any] - named_array = from_array([], np.datetime64("2000-01-01")) - assert named_array.dtype == np.dtype("datetime64[D]") - - -@pytest.mark.parametrize( - "timedelta, expected_dtype", - [ - (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), - (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), - (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), - (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), - (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), - (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), - (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), - (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), - (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), - ], -) -def test_0d_timedelta( - timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] -) -> None: - named_array: NamedArray[Any, Any] - named_array = from_array([], timedelta) - assert named_array.dtype == expected_dtype - assert named_array.data == timedelta - - -@pytest.mark.parametrize( - "dims, data_shape, new_dims, raises", - [ - (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), - (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), - (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), - ([], [], (), False), - ([], [], ("x",), True), - ], -) -def test_dims_setter(dims: Any, data_shape: Any, new_dims: Any, raises: bool) -> None: - named_array: NamedArray[Any, Any] - named_array = NamedArray(dims, np.asarray(np.random.random(data_shape))) - assert named_array.dims == tuple(dims) - if raises: + def test_data(self, random_inputs: Any) -> None: + expected = self.cls(["x", "y", "z"], random_inputs) + assert np.array_equal(np.asarray(expected.data), random_inputs) with pytest.raises(ValueError): - named_array.dims = new_dims - else: - named_array.dims = new_dims - assert named_array.dims == tuple(new_dims) - - -def test_duck_array_class() -> None: - def test_duck_array_typevar(a: duckarray[Any, _DType]) -> duckarray[Any, _DType]: - # Mypy checks a is valid: - b: duckarray[Any, _DType] = a - - # Runtime check if valid: - if isinstance(b, _arrayfunction_or_api): - return b + expected.data = np.random.random((3, 4)).astype(np.float64) + d2 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5)) + expected.data = d2 + assert np.array_equal(np.asarray(expected.data), d2) + + +class TestNamedArray(NamedArraySubclassobjects): + def cls(self, *args: Any, **kwargs: Any) -> NamedArray[Any, Any]: + return NamedArray(*args, **kwargs) + + @pytest.fixture + def target(self, data: np.ndarray[Any, Any]) -> NamedArray[Any, Any]: + return NamedArray(["x", "y"], data) + + @pytest.mark.parametrize( + "expected", + [ + np.array([1, 2], dtype=np.dtype(np.int8)), + pytest.param( + [1, 2], + marks=pytest.mark.xfail( + reason="NamedArray only supports array-like objects" + ), + ), + ], + ) + def test_init(self, expected: Any) -> None: + super().test_init(expected) + + @pytest.mark.parametrize( + "dims, data, expected, raise_error", + [ + (("x",), [1, 2, 3], np.array([1, 2, 3]), False), + ((1,), np.array([4, 5, 6]), np.array([4, 5, 6]), False), + ((), 2, np.array(2), False), + # Fail: + ( + ("x",), + NamedArray("time", np.array([1, 2, 3])), + np.array([1, 2, 3]), + True, + ), + ], + ) + def test_from_array( + self, + dims: _DimsLike, + data: ArrayLike, + expected: np.ndarray[Any, Any], + raise_error: bool, + ) -> None: + actual: NamedArray[Any, Any] + if raise_error: + with pytest.raises(TypeError, match="already a Named array"): + actual = from_array(dims, data) + + # Named arrays are not allowed: + from_array(actual) # type: ignore[call-overload] else: - raise TypeError(f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi") - - numpy_a: NDArray[np.int64] - numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) - test_duck_array_typevar(numpy_a) - - masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]] - masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call] - test_duck_array_typevar(masked_a) - - custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] - custom_a = CustomArrayIndexable(numpy_a) - test_duck_array_typevar(custom_a) - - # Test numpy's array api: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - r"The numpy.array_api submodule is still experimental", - category=UserWarning, - ) - import numpy.array_api as nxp - - # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: - arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] - arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) - test_duck_array_typevar(arrayapi_a) - - -def test_new_namedarray() -> None: - dtype_float = np.dtype(np.float32) - narr_float: NamedArray[Any, np.dtype[np.float32]] - narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) - assert narr_float.dtype == dtype_float - - dtype_int = np.dtype(np.int8) - narr_int: NamedArray[Any, np.dtype[np.int8]] - narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) - assert narr_int.dtype == dtype_int - - # Test with a subclass: - class Variable( - NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] - ): - @overload - def _new( - self, - dims: _DimsLike | Default = ..., - data: duckarray[Any, _DType] = ..., - attrs: _AttrsLike | Default = ..., - ) -> Variable[Any, _DType]: - ... - - @overload - def _new( - self, - dims: _DimsLike | Default = ..., - data: Default = ..., - attrs: _AttrsLike | Default = ..., - ) -> Variable[_ShapeType_co, _DType_co]: - ... - - def _new( - self, - dims: _DimsLike | Default = _default, - data: duckarray[Any, _DType] | Default = _default, - attrs: _AttrsLike | Default = _default, - ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: - dims_ = copy.copy(self._dims) if dims is _default else dims - - attrs_: Mapping[Any, Any] | None - if attrs is _default: - attrs_ = None if self._attrs is None else self._attrs.copy() - else: - attrs_ = attrs - - if data is _default: - return type(self)(dims_, copy.copy(self._data), attrs_) - else: - cls_ = cast("type[Variable[Any, _DType]]", type(self)) - return cls_(dims_, data, attrs_) - - var_float: Variable[Any, np.dtype[np.float32]] - var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) - assert var_float.dtype == dtype_float - - var_int: Variable[Any, np.dtype[np.int8]] - var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int)) - assert var_int.dtype == dtype_int - - -def test_replace_namedarray() -> None: - dtype_float = np.dtype(np.float32) - np_val: np.ndarray[Any, np.dtype[np.float32]] - np_val = np.array([1.5, 3.2], dtype=dtype_float) - np_val2: np.ndarray[Any, np.dtype[np.float32]] - np_val2 = 2 * np_val - - narr_float: NamedArray[Any, np.dtype[np.float32]] - narr_float = NamedArray(("x",), np_val) - assert narr_float.dtype == dtype_float - - narr_float2: NamedArray[Any, np.dtype[np.float32]] - narr_float2 = NamedArray(("x",), np_val2) - assert narr_float2.dtype == dtype_float - - # Test with a subclass: - class Variable( - NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] - ): - @overload - def _new( - self, - dims: _DimsLike | Default = ..., - data: duckarray[Any, _DType] = ..., - attrs: _AttrsLike | Default = ..., - ) -> Variable[Any, _DType]: - ... - - @overload - def _new( - self, - dims: _DimsLike | Default = ..., - data: Default = ..., - attrs: _AttrsLike | Default = ..., - ) -> Variable[_ShapeType_co, _DType_co]: - ... - - def _new( - self, - dims: _DimsLike | Default = _default, - data: duckarray[Any, _DType] | Default = _default, - attrs: _AttrsLike | Default = _default, - ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: - dims_ = copy.copy(self._dims) if dims is _default else dims - - attrs_: Mapping[Any, Any] | None - if attrs is _default: - attrs_ = None if self._attrs is None else self._attrs.copy() - else: - attrs_ = attrs + actual = from_array(dims, data) - if data is _default: - return type(self)(dims_, copy.copy(self._data), attrs_) + assert np.array_equal(np.asarray(actual.data), expected) + + def test_from_array_with_masked_array(self) -> None: + masked_array: np.ndarray[Any, np.dtype[np.generic]] + masked_array = np.ma.array([1, 2, 3], mask=[False, True, False]) # type: ignore[no-untyped-call] + with pytest.raises(NotImplementedError): + from_array(("x",), masked_array) + + def test_from_array_with_0d_object(self) -> None: + data = np.empty((), dtype=object) + data[()] = (10, 12, 12) + narr = from_array((), data) + np.array_equal(np.asarray(narr.data), data) + + # TODO: Make xr.core.indexing.ExplicitlyIndexed pass as a subclass of_arrayfunction_or_api + # and remove this test. + def test_from_array_with_explicitly_indexed( + self, random_inputs: np.ndarray[Any, Any] + ) -> None: + array: CustomArray[Any, Any] + array = CustomArray(random_inputs) + output: NamedArray[Any, Any] + output = from_array(("x", "y", "z"), array) + assert isinstance(output.data, np.ndarray) + + array2: CustomArrayIndexable[Any, Any] + array2 = CustomArrayIndexable(random_inputs) + output2: NamedArray[Any, Any] + output2 = from_array(("x", "y", "z"), array2) + assert isinstance(output2.data, CustomArrayIndexable) + + def test_real_and_imag(self) -> None: + expected_real: np.ndarray[Any, np.dtype[np.float64]] + expected_real = np.arange(3, dtype=np.float64) + + expected_imag: np.ndarray[Any, np.dtype[np.float64]] + expected_imag = -np.arange(3, dtype=np.float64) + + arr: np.ndarray[Any, np.dtype[np.complex128]] + arr = expected_real + 1j * expected_imag + + named_array: NamedArray[Any, np.dtype[np.complex128]] + named_array = NamedArray(["x"], arr) + + actual_real: duckarray[Any, np.dtype[np.float64]] = named_array.real.data + assert np.array_equal(np.asarray(actual_real), expected_real) + assert actual_real.dtype == expected_real.dtype + + actual_imag: duckarray[Any, np.dtype[np.float64]] = named_array.imag.data + assert np.array_equal(np.asarray(actual_imag), expected_imag) + assert actual_imag.dtype == expected_imag.dtype + + # Additional tests as per your original class-based code + @pytest.mark.parametrize( + "data, dtype", + [ + ("foo", np.dtype("U3")), + (b"foo", np.dtype("S3")), + ], + ) + def test_from_array_0d_string(self, data: Any, dtype: DTypeLike) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], data) + assert named_array.data == data + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == dtype + + def test_from_array_0d_object(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], (10, 12, 12)) + expected_data = np.empty((), dtype=object) + expected_data[()] = (10, 12, 12) + assert np.array_equal(np.asarray(named_array.data), expected_data) + + assert named_array.dims == () + assert named_array.sizes == {} + assert named_array.attrs == {} + assert named_array.ndim == 0 + assert named_array.size == 1 + assert named_array.dtype == np.dtype("O") + + def test_from_array_0d_datetime(self) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], np.datetime64("2000-01-01")) + assert named_array.dtype == np.dtype("datetime64[D]") + + @pytest.mark.parametrize( + "timedelta, expected_dtype", + [ + (np.timedelta64(1, "D"), np.dtype("timedelta64[D]")), + (np.timedelta64(1, "s"), np.dtype("timedelta64[s]")), + (np.timedelta64(1, "m"), np.dtype("timedelta64[m]")), + (np.timedelta64(1, "h"), np.dtype("timedelta64[h]")), + (np.timedelta64(1, "us"), np.dtype("timedelta64[us]")), + (np.timedelta64(1, "ns"), np.dtype("timedelta64[ns]")), + (np.timedelta64(1, "ps"), np.dtype("timedelta64[ps]")), + (np.timedelta64(1, "fs"), np.dtype("timedelta64[fs]")), + (np.timedelta64(1, "as"), np.dtype("timedelta64[as]")), + ], + ) + def test_from_array_0d_timedelta( + self, timedelta: np.timedelta64, expected_dtype: np.dtype[np.timedelta64] + ) -> None: + named_array: NamedArray[Any, Any] + named_array = from_array([], timedelta) + assert named_array.dtype == expected_dtype + assert named_array.data == timedelta + + @pytest.mark.parametrize( + "dims, data_shape, new_dims, raises", + [ + (["x", "y", "z"], (2, 3, 4), ["a", "b", "c"], False), + (["x", "y", "z"], (2, 3, 4), ["a", "b"], True), + (["x", "y", "z"], (2, 4, 5), ["a", "b", "c", "d"], True), + ([], [], (), False), + ([], [], ("x",), True), + ], + ) + def test_dims_setter( + self, dims: Any, data_shape: Any, new_dims: Any, raises: bool + ) -> None: + named_array: NamedArray[Any, Any] + named_array = NamedArray(dims, np.asarray(np.random.random(data_shape))) + assert named_array.dims == tuple(dims) + if raises: + with pytest.raises(ValueError): + named_array.dims = new_dims + else: + named_array.dims = new_dims + assert named_array.dims == tuple(new_dims) + + def test_duck_array_class( + self, + ) -> None: + def test_duck_array_typevar( + a: duckarray[Any, _DType] + ) -> duckarray[Any, _DType]: + # Mypy checks a is valid: + b: duckarray[Any, _DType] = a + + # Runtime check if valid: + if isinstance(b, _arrayfunction_or_api): + return b else: - cls_ = cast("type[Variable[Any, _DType]]", type(self)) - return cls_(dims_, data, attrs_) - - var_float: Variable[Any, np.dtype[np.float32]] - var_float = Variable(("x",), np_val) - assert var_float.dtype == dtype_float - - var_float2: Variable[Any, np.dtype[np.float32]] - var_float2 = var_float._replace(("x",), np_val2) - assert var_float2.dtype == dtype_float + raise TypeError( + f"a ({type(a)}) is not a valid _arrayfunction or _arrayapi" + ) + + numpy_a: NDArray[np.int64] + numpy_a = np.array([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(numpy_a) + + masked_a: np.ma.MaskedArray[Any, np.dtype[np.int64]] + masked_a = np.ma.asarray([2.1, 4], dtype=np.dtype(np.int64)) # type: ignore[no-untyped-call] + test_duck_array_typevar(masked_a) + + custom_a: CustomArrayIndexable[Any, np.dtype[np.int64]] + custom_a = CustomArrayIndexable(numpy_a) + test_duck_array_typevar(custom_a) + + # Test numpy's array api: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"The numpy.array_api submodule is still experimental", + category=UserWarning, + ) + import numpy.array_api as nxp + + # TODO: nxp doesn't use dtype typevars, so can only use Any for the moment: + arrayapi_a: duckarray[Any, Any] # duckarray[Any, np.dtype[np.int64]] + arrayapi_a = nxp.asarray([2.1, 4], dtype=np.dtype(np.int64)) + test_duck_array_typevar(arrayapi_a) + + def test_new_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert narr_float.dtype == dtype_float + + dtype_int = np.dtype(np.int8) + narr_int: NamedArray[Any, np.dtype[np.int8]] + narr_int = narr_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert narr_int.dtype == dtype_int + + # Test with a subclass: + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: + ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: + ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + else: + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np.array([1.5, 3.2], dtype=dtype_float)) + assert var_float.dtype == dtype_float + + var_int: Variable[Any, np.dtype[np.int8]] + var_int = var_float._new(("x",), np.array([1, 3], dtype=dtype_int)) + assert var_int.dtype == dtype_int + + def test_replace_namedarray(self) -> None: + dtype_float = np.dtype(np.float32) + np_val: np.ndarray[Any, np.dtype[np.float32]] + np_val = np.array([1.5, 3.2], dtype=dtype_float) + np_val2: np.ndarray[Any, np.dtype[np.float32]] + np_val2 = 2 * np_val + + narr_float: NamedArray[Any, np.dtype[np.float32]] + narr_float = NamedArray(("x",), np_val) + assert narr_float.dtype == dtype_float + + narr_float2: NamedArray[Any, np.dtype[np.float32]] + narr_float2 = NamedArray(("x",), np_val2) + assert narr_float2.dtype == dtype_float + + # Test with a subclass: + class Variable( + NamedArray[_ShapeType_co, _DType_co], Generic[_ShapeType_co, _DType_co] + ): + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: duckarray[Any, _DType] = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[Any, _DType]: + ... + + @overload + def _new( + self, + dims: _DimsLike | Default = ..., + data: Default = ..., + attrs: _AttrsLike | Default = ..., + ) -> Variable[_ShapeType_co, _DType_co]: + ... + + def _new( + self, + dims: _DimsLike | Default = _default, + data: duckarray[Any, _DType] | Default = _default, + attrs: _AttrsLike | Default = _default, + ) -> Variable[Any, _DType] | Variable[_ShapeType_co, _DType_co]: + dims_ = copy.copy(self._dims) if dims is _default else dims + + attrs_: Mapping[Any, Any] | None + if attrs is _default: + attrs_ = None if self._attrs is None else self._attrs.copy() + else: + attrs_ = attrs + + if data is _default: + return type(self)(dims_, copy.copy(self._data), attrs_) + else: + cls_ = cast("type[Variable[Any, _DType]]", type(self)) + return cls_(dims_, data, attrs_) + + var_float: Variable[Any, np.dtype[np.float32]] + var_float = Variable(("x",), np_val) + assert var_float.dtype == dtype_float + + var_float2: Variable[Any, np.dtype[np.float32]] + var_float2 = var_float._replace(("x",), np_val2) + assert var_float2.dtype == dtype_float diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index d91cf85e4eb..0bea3f63673 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from abc import ABC, abstractmethod +from abc import ABC from copy import copy, deepcopy from datetime import datetime, timedelta from textwrap import dedent @@ -46,6 +46,7 @@ requires_sparse, source_ndarray, ) +from xarray.tests.test_namedarray import NamedArraySubclassobjects dask_array_type = array_type("dask") @@ -63,34 +64,11 @@ def var(): return Variable(dims=list("xyz"), data=np.random.rand(3, 4, 5)) -class VariableSubclassobjects(ABC): - @abstractmethod - def cls(self, *args, **kwargs) -> Variable: - raise NotImplementedError - - def test_properties(self): - data = 0.5 * np.arange(10) - v = self.cls(["time"], data, {"foo": "bar"}) - assert v.dims == ("time",) - assert_array_equal(v.values, data) - assert v.dtype == float - assert v.shape == (10,) - assert v.size == 10 - assert v.sizes == {"time": 10} - assert v.nbytes == 80 - assert v.ndim == 1 - assert len(v) == 10 - assert v.attrs == {"foo": "bar"} - - def test_attrs(self): - v = self.cls(["time"], 0.5 * np.arange(10)) - assert v.attrs == {} - attrs = {"foo": "bar"} - v.attrs = attrs - assert v.attrs == attrs - assert isinstance(v.attrs, dict) - v.attrs["foo"] = "baz" - assert v.attrs["foo"] == "baz" +class VariableSubclassobjects(NamedArraySubclassobjects, ABC): + @pytest.fixture + def target(self, data): + data = 0.5 * np.arange(10).reshape(2, 5) + return Variable(["x", "y"], data) def test_getitem_dict(self): v = self.cls(["x"], np.random.randn(5)) @@ -368,7 +346,7 @@ def test_1d_math(self, dtype: np.typing.DTypeLike) -> None: assert_array_equal(v >> 2, x >> 2) # binary ops with numpy arrays assert_array_equal((v * x).values, x**2) - assert_array_equal((x * v).values, x**2) # type: ignore[attr-defined] # TODO: Fix mypy thinking numpy takes priority, GH7780 + assert_array_equal((x * v).values, x**2) assert_array_equal(v - y, v - 1) assert_array_equal(y - v, 1 - v) if dtype is int: @@ -1065,9 +1043,8 @@ def cls(self, *args, **kwargs) -> Variable: def setup(self): self.d = np.random.random((10, 3)).astype(np.float64) - def test_data_and_values(self): + def test_values(self): v = Variable(["time", "x"], self.d) - assert_array_equal(v.data, self.d) assert_array_equal(v.values, self.d) assert source_ndarray(v.values) is self.d with pytest.raises(ValueError): @@ -1076,9 +1053,6 @@ def test_data_and_values(self): d2 = np.random.random((10, 3)) v.values = d2 assert source_ndarray(v.values) is d2 - d3 = np.random.random((10, 3)) - v.data = d3 - assert source_ndarray(v.data) is d3 def test_numpy_same_methods(self): v = Variable([], np.float32(0.0)) From e7e8c38566c011b50a8b1980c2e563a1db3cbed5 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Tue, 28 Nov 2023 13:04:47 -0800 Subject: [PATCH 13/14] Start renaming `dims` to `dim` (#8487) * Start renaming `dims` to `dim` Begins the process of #6646. I don't think it's feasible / enjoyable to do this for everything at once, so I would suggest we do it gradually, while keeping the warnings quite quiet, so by the time we convert to louder warnings, users can do a find/replace easily. * No deprecation for internal methods * Simplify typing --- doc/whats-new.rst | 7 ++++++ xarray/core/alignment.py | 14 +++++------ xarray/core/computation.py | 28 +++++++++++---------- xarray/core/dataarray.py | 17 +++++++------ xarray/core/variable.py | 6 ++--- xarray/core/weighted.py | 2 +- xarray/tests/test_computation.py | 40 +++++++++++++++--------------- xarray/tests/test_dataarray.py | 6 ++--- xarray/util/deprecation_helpers.py | 27 ++++++++++++++++++++ 9 files changed, 92 insertions(+), 55 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 71cbc1a08ee..92048e02837 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,13 @@ Breaking changes Deprecations ~~~~~~~~~~~~ +- As part of an effort to standardize the API, we're renaming the ``dims`` + keyword arg to ``dim`` for the minority of functions which current use + ``dims``. This started with :py:func:`xarray.dot` & :py:meth:`DataArray.dot` + and we'll gradually roll this out across all functions. The warnings are + currently ``PendingDeprecationWarning``, which are silenced by default. We'll + convert these to ``DeprecationWarning`` in a future release. + By `Maximilian Roos `_. Bug fixes ~~~~~~~~~ diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 732ec5d3ea6..041fe63a9f3 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -324,7 +324,7 @@ def assert_no_index_conflict(self) -> None: "- they may be used to reindex data along common dimensions" ) - def _need_reindex(self, dims, cmp_indexes) -> bool: + def _need_reindex(self, dim, cmp_indexes) -> bool: """Whether or not we need to reindex variables for a set of matching indexes. @@ -340,14 +340,14 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: return True unindexed_dims_sizes = {} - for dim in dims: - if dim in self.unindexed_dim_sizes: - sizes = self.unindexed_dim_sizes[dim] + for d in dim: + if d in self.unindexed_dim_sizes: + sizes = self.unindexed_dim_sizes[d] if len(sizes) > 1: # reindex if different sizes are found for unindexed dims return True else: - unindexed_dims_sizes[dim] = next(iter(sizes)) + unindexed_dims_sizes[d] = next(iter(sizes)) if unindexed_dims_sizes: indexed_dims_sizes = {} @@ -356,8 +356,8 @@ def _need_reindex(self, dims, cmp_indexes) -> bool: for var in index_vars.values(): indexed_dims_sizes.update(var.sizes) - for dim, size in unindexed_dims_sizes.items(): - if indexed_dims_sizes.get(dim, -1) != size: + for d, size in unindexed_dims_sizes.items(): + if indexed_dims_sizes.get(d, -1) != size: # reindex if unindexed dimension size doesn't match return True diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0c5c9d6d5cb..ed2c733d4ca 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -26,6 +26,7 @@ from xarray.core.types import Dims, T_DataArray from xarray.core.utils import is_dict_like, is_scalar from xarray.core.variable import Variable +from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: from xarray.core.coordinates import Coordinates @@ -1691,9 +1692,10 @@ def cross( return c +@deprecate_dims def dot( *arrays, - dims: Dims = None, + dim: Dims = None, **kwargs: Any, ): """Generalized dot product for xarray objects. Like ``np.einsum``, but @@ -1703,7 +1705,7 @@ def dot( ---------- *arrays : DataArray or Variable Arrays to compute. - dims : str, iterable of hashable, "..." or None, optional + dim : str, iterable of hashable, "..." or None, optional Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. If not specified, then all the common dimensions are summed over. **kwargs : dict @@ -1756,18 +1758,18 @@ def dot( [3, 4, 5]]) Dimensions without coordinates: c, d - >>> xr.dot(da_a, da_b, dims=["a", "b"]) + >>> xr.dot(da_a, da_b, dim=["a", "b"]) array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=["a"]) + >>> xr.dot(da_a, da_b, dim=["a"]) array([[40, 46], [70, 79]]) Dimensions without coordinates: b, c - >>> xr.dot(da_a, da_b, da_c, dims=["b", "c"]) + >>> xr.dot(da_a, da_b, da_c, dim=["b", "c"]) array([[ 9, 14, 19], [ 93, 150, 207], @@ -1779,7 +1781,7 @@ def dot( array([110, 125]) Dimensions without coordinates: c - >>> xr.dot(da_a, da_b, dims=...) + >>> xr.dot(da_a, da_b, dim=...) array(235) """ @@ -1803,18 +1805,18 @@ def dot( einsum_axes = "abcdefghijklmnopqrstuvwxyz" dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)} - if dims is ...: - dims = all_dims - elif isinstance(dims, str): - dims = (dims,) - elif dims is None: + if dim is ...: + dim = all_dims + elif isinstance(dim, str): + dim = (dim,) + elif dim is None: # find dimensions that occur more than one times dim_counts: Counter = Counter() for arr in arrays: dim_counts.update(arr.dims) - dims = tuple(d for d, c in dim_counts.items() if c > 1) + dim = tuple(d for d, c in dim_counts.items() if c > 1) - dot_dims: set[Hashable] = set(dims) + dot_dims: set[Hashable] = set(dim) # dimensions to be parallelized broadcast_dims = common_dims - dot_dims diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index b417470fdc0..47708cfb581 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -65,7 +65,7 @@ ) from xarray.plot.accessor import DataArrayPlotAccessor from xarray.plot.utils import _get_units_from_attrs -from xarray.util.deprecation_helpers import _deprecate_positional_args +from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims if TYPE_CHECKING: from typing import TypeVar, Union @@ -115,14 +115,14 @@ T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) -def _check_coords_dims(shape, coords, dims): - sizes = dict(zip(dims, shape)) +def _check_coords_dims(shape, coords, dim): + sizes = dict(zip(dim, shape)) for k, v in coords.items(): - if any(d not in dims for d in v.dims): + if any(d not in dim for d in v.dims): raise ValueError( f"coordinate {k} has dimensions {v.dims}, but these " "are not a subset of the DataArray " - f"dimensions {dims}" + f"dimensions {dim}" ) for d, s in v.sizes.items(): @@ -4895,10 +4895,11 @@ def imag(self) -> Self: """ return self._replace(self.variable.imag) + @deprecate_dims def dot( self, other: T_Xarray, - dims: Dims = None, + dim: Dims = None, ) -> T_Xarray: """Perform dot product of two DataArrays along their shared dims. @@ -4908,7 +4909,7 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. - dims : ..., str, Iterable of Hashable or None, optional + dim : ..., str, Iterable of Hashable or None, optional Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. If not specified, then all the common dimensions are summed over. @@ -4947,7 +4948,7 @@ def dot( if not isinstance(other, DataArray): raise TypeError("dot only operates on DataArrays.") - return computation.dot(self, other, dims=dims) + return computation.dot(self, other, dim=dim) def sortby( self, diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c2133d55aeb..39a947e6264 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1541,15 +1541,15 @@ def stack(self, dimensions=None, **dimensions_kwargs): result = result._stack_once(dims, new_dim) return result - def _unstack_once_full(self, dims: Mapping[Any, int], old_dim: Hashable) -> Self: + def _unstack_once_full(self, dim: Mapping[Any, int], old_dim: Hashable) -> Self: """ Unstacks the variable without needing an index. Unlike `_unstack_once`, this function requires the existing dimension to contain the full product of the new dimensions. """ - new_dim_names = tuple(dims.keys()) - new_dim_sizes = tuple(dims.values()) + new_dim_names = tuple(dim.keys()) + new_dim_sizes = tuple(dim.values()) if old_dim not in self.dims: raise ValueError(f"invalid existing dimension: {old_dim}") diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 28740a99020..53ff6db5f28 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -228,7 +228,7 @@ def _reduce( # `dot` does not broadcast arrays, so this avoids creating a large # DataArray (if `weights` has additional dimensions) - return dot(da, weights, dims=dim) + return dot(da, weights, dim=dim) def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray: """Calculate the sum of weights, accounting for missing values""" diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 425673dc40f..396507652c6 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1936,7 +1936,7 @@ def test_dot(use_dask: bool) -> None: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) da_c = da_c.chunk({"c": 3}) - actual = xr.dot(da_a, da_b, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, dim=["a", "b"]) assert actual.dims == ("c",) assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) @@ -1960,33 +1960,33 @@ def test_dot(use_dask: bool) -> None: if use_dask: da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert isinstance(actual.variable.data, type(da_a.variable.data)) - actual = xr.dot(da_a, da_b, dims=["b"]) + actual = xr.dot(da_a, da_b, dim=["b"]) assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="b") + actual = xr.dot(da_a, da_b, dim="b") assert actual.dims == ("a", "c") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() - actual = xr.dot(da_a, da_b, dims="a") + actual = xr.dot(da_a, da_b, dim="a") assert actual.dims == ("b", "c") assert (actual.data == np.einsum("ij,ijk->jk", a, b)).all() - actual = xr.dot(da_a, da_b, dims="c") + actual = xr.dot(da_a, da_b, dim="c") assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=["a", "b"]) + actual = xr.dot(da_a, da_b, da_c, dim=["a", "b"]) assert actual.dims == ("c", "e") assert (actual.data == np.einsum("ij,ijk,kl->kl ", a, b, c)).all() # should work with tuple - actual = xr.dot(da_a, da_b, dims=("c",)) + actual = xr.dot(da_a, da_b, dim=("c",)) assert actual.dims == ("a", "b") assert (actual.data == np.einsum("ij,ijk->ij", a, b)).all() @@ -1996,47 +1996,47 @@ def test_dot(use_dask: bool) -> None: assert (actual.data == np.einsum("ij,ijk,kl->l ", a, b, c)).all() # 1 array summation - actual = xr.dot(da_a, dims="a") + actual = xr.dot(da_a, dim="a") assert actual.dims == ("b",) assert (actual.data == np.einsum("ij->j ", a)).all() # empty dim - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims="a") + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim="a") assert actual.dims == ("b",) assert (actual.data == np.zeros(actual.shape)).all() # Ellipsis (...) sums over all dimensions - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk->", a, b)).all() - actual = xr.dot(da_a, da_b, da_c, dims=...) + actual = xr.dot(da_a, da_b, da_c, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij,ijk,kl-> ", a, b, c)).all() - actual = xr.dot(da_a, dims=...) + actual = xr.dot(da_a, dim=...) assert actual.dims == () assert (actual.data == np.einsum("ij-> ", a)).all() - actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dims=...) + actual = xr.dot(da_a.sel(a=[]), da_a.sel(a=[]), dim=...) assert actual.dims == () assert (actual.data == np.zeros(actual.shape)).all() # Invalid cases if not use_dask: with pytest.raises(TypeError): - xr.dot(da_a, dims="a", invalid=None) + xr.dot(da_a, dim="a", invalid=None) with pytest.raises(TypeError): - xr.dot(da_a.to_dataset(name="da"), dims="a") + xr.dot(da_a.to_dataset(name="da"), dim="a") with pytest.raises(TypeError): - xr.dot(dims="a") + xr.dot(dim="a") # einsum parameters - actual = xr.dot(da_a, da_b, dims=["b"], order="C") + actual = xr.dot(da_a, da_b, dim=["b"], order="C") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() assert actual.values.flags["C_CONTIGUOUS"] assert not actual.values.flags["F_CONTIGUOUS"] - actual = xr.dot(da_a, da_b, dims=["b"], order="F") + actual = xr.dot(da_a, da_b, dim=["b"], order="F") assert (actual.data == np.einsum("ij,ijk->ik", a, b)).all() # dask converts Fortran arrays to C order when merging the final array if not use_dask: @@ -2078,7 +2078,7 @@ def test_dot_align_coords(use_dask: bool) -> None: expected = (da_a * da_b).sum(["a", "b"]) xr.testing.assert_allclose(expected, actual) - actual = xr.dot(da_a, da_b, dims=...) + actual = xr.dot(da_a, da_b, dim=...) expected = (da_a * da_b).sum() xr.testing.assert_allclose(expected, actual) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 44b9790f0b7..f9547f3afa2 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -3964,13 +3964,13 @@ def test_dot(self) -> None: assert_equal(expected3, actual3) # Ellipsis: all dims are shared - actual4 = da.dot(da, dims=...) + actual4 = da.dot(da, dim=...) expected4 = da.dot(da) assert_equal(expected4, actual4) # Ellipsis: not all dims are shared - actual5 = da.dot(dm3, dims=...) - expected5 = da.dot(dm3, dims=("j", "x", "y", "z")) + actual5 = da.dot(dm3, dim=...) + expected5 = da.dot(dm3, dim=("j", "x", "y", "z")) assert_equal(expected5, actual5) with pytest.raises(NotImplementedError): diff --git a/xarray/util/deprecation_helpers.py b/xarray/util/deprecation_helpers.py index 7b4cf901aa1..c620e45574e 100644 --- a/xarray/util/deprecation_helpers.py +++ b/xarray/util/deprecation_helpers.py @@ -36,6 +36,8 @@ from functools import wraps from typing import Callable, TypeVar +from xarray.core.utils import emit_user_level_warning + T = TypeVar("T", bound=Callable) POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD @@ -115,3 +117,28 @@ def inner(*args, **kwargs): return inner return _decorator + + +def deprecate_dims(func: T) -> T: + """ + For functions that previously took `dims` as a kwarg, and have now transitioned to + `dim`. This decorator will issue a warning if `dims` is passed while forwarding it + to `dim`. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + if "dims" in kwargs: + emit_user_level_warning( + "The `dims` argument has been renamed to `dim`, and will be removed " + "in the future. This renaming is taking place throughout xarray over the " + "next few releases.", + # Upgrade to `DeprecationWarning` in the future, when the renaming is complete. + PendingDeprecationWarning, + ) + kwargs["dim"] = kwargs.pop("dims") + return func(*args, **kwargs) + + # We're quite confident we're just returning `T` from this function, so it's fine to ignore typing + # within the function. + return wrapper # type: ignore From dc0931ad05f631135baa9889bdceeb15e2fa727c Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Tue, 28 Nov 2023 14:19:00 -0800 Subject: [PATCH 14/14] Raise an informative error message when object array has mixed types (#4700) Co-authored-by: Mathias Hauser Co-authored-by: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> --- xarray/conventions.py | 24 ++++++++++++++++++++---- xarray/tests/test_conventions.py | 12 ++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/xarray/conventions.py b/xarray/conventions.py index 75f816e6cb4..8c7d6be2309 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -52,16 +52,32 @@ def _var_as_tuple(var: Variable) -> T_VarTuple: return var.dims, var.data, var.attrs.copy(), var.encoding.copy() -def _infer_dtype(array, name: T_Name = None) -> np.dtype: - """Given an object array with no missing values, infer its dtype from its - first element - """ +def _infer_dtype(array, name=None): + """Given an object array with no missing values, infer its dtype from all elements.""" if array.dtype.kind != "O": raise TypeError("infer_type must be called on a dtype=object array") if array.size == 0: return np.dtype(float) + native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) + if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: + raise ValueError( + "unable to infer dtype on variable {!r}; object array " + "contains mixed native types: {}".format( + name, ", ".join(x.__name__ for x in native_dtypes) + ) + ) + + native_dtypes = set(np.vectorize(type, otypes=[object])(array.ravel())) + if len(native_dtypes) > 1 and native_dtypes != {bytes, str}: + raise ValueError( + "unable to infer dtype on variable {!r}; object array " + "contains mixed native types: {}".format( + name, ", ".join(x.__name__ for x in native_dtypes) + ) + ) + element = array[(0,) * array.ndim] # We use the base types to avoid subclasses of bytes and str (which might # not play nice with e.g. hdf5 datatypes), such as those from numpy diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index d6d1303a696..be6e949edf8 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -495,6 +495,18 @@ def test_encoding_kwarg_fixed_width_string(self) -> None: pass +@pytest.mark.parametrize( + "data", + [ + np.array([["ab", "cdef", b"X"], [1, 2, "c"]], dtype=object), + np.array([["x", 1], ["y", 2]], dtype="object"), + ], +) +def test_infer_dtype_error_on_mixed_types(data): + with pytest.raises(ValueError, match="unable to infer dtype on variable"): + conventions._infer_dtype(data, "test") + + class TestDecodeCFVariableWithArrayUnits: def test_decode_cf_variable_with_array_units(self) -> None: v = Variable(["t"], [1, 2, 3], {"units": np.array(["foobar"], dtype=object)})