diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cbb63feda96..3287374d5c3 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,7 +34,12 @@ Deprecations Bug fixes ~~~~~~~~~ - +- Require explicitly defining optional dimensions such as hue + and markersize for scatter plots (:issue:`7314`, :pull:`7277`). + By `Jimmy Westling `_. +- Fix matplotlib raising a UserWarning when plotting a scatter plot + with an unfilled marker (:issue:`7313`, :pull:`7318`). + By `Jimmy Westling `_. Documentation ~~~~~~~~~~~~~ @@ -194,9 +199,6 @@ Bug fixes By `Michael Niklas `_. - Fix static typing of :py:meth:`xr.polyval` (:issue:`7312`, :pull:`7315`). By `Michael Niklas `_. -- Fix matplotlib raising a UserWarning when plotting a scatter plot - with an unfilled marker (:issue:`7313`, :pull:`7318`). - By `Jimmy Westling `_. - Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`). By `David Hoese `_ and `Wei Ji Leong `_. - Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`). 
diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index e3910c24ce3..4c77539b5bb 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -19,6 +19,7 @@ _assert_valid_xy, _determine_guide, _ensure_plottable, + _guess_coords_to_plot, _infer_interval_breaks, _infer_xy_labels, _Normalize, @@ -142,48 +143,45 @@ def _infer_line_data( return xplt, yplt, hueplt, huelabel -def _infer_plot_dims( - darray: DataArray, - dims_plot: MutableMapping[str, Hashable], - default_guess: Iterable[str] = ("x", "hue", "size"), -) -> MutableMapping[str, Hashable]: +def _prepare_plot1d_data( + darray: T_DataArray, + coords_to_plot: MutableMapping[str, Hashable], + plotfunc_name: str | None = None, + _is_facetgrid: bool = False, +) -> dict[str, T_DataArray]: """ - Guess what dims to plot if some of the values in dims_plot are None which - happens when the user has not defined all available ways of visualizing - the data. + Prepare data for usage with plt.scatter. Parameters ---------- - darray : DataArray - The DataArray to check. - dims_plot : T_DimsPlot - Dims defined by the user to plot. - default_guess : Iterable[str], optional - Default values and order to retrieve dims if values in dims_plot is - missing, default: ("x", "hue", "size"). - """ - dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None} - dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values()) - - # If dims_plot[k] isn't defined then fill with one of the available dims: - for k, v in zip(default_guess, dims_avail): - if dims_plot.get(k, None) is None: - dims_plot[k] = v - - for k, v in dims_plot.items(): - _assert_valid_xy(darray, v, k) - - return dims_plot - + darray : T_DataArray + Base DataArray. + coords_to_plot : MutableMapping[str, Hashable] + Coords that will be plotted. + plotfunc_name : str | None + Name of the plotting function that will be used. 
-def _infer_line_data2( - darray: T_DataArray, - dims_plot: MutableMapping[str, Hashable], - plotfunc_name: None | str = None, -) -> dict[str, T_DataArray]: - # Guess what dims to use if some of the values in plot_dims are None: - dims_plot = _infer_plot_dims(darray, dims_plot) + Returns + ------- + plts : dict[str, T_DataArray] + Dict of DataArrays that will be sent to matplotlib. + Examples + -------- + >>> # Make sure int coords are plotted: + >>> a = xr.DataArray( + ... data=[1, 2], + ... coords={1: ("x", [0, 1], {"units": "s"})}, + ... dims=("x",), + ... name="a", + ... ) + >>> plts = xr.plot.dataarray_plot._prepare_plot1d_data( + ... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None} + ... ) + >>> # Check which coords to plot: + >>> print({k: v.name for k, v in plts.items()}) + {'y': 'a', 'x': 1} + """ # If there are more than 1 dimension in the array than stack all the # dimensions so the plotter can plot anything: if darray.ndim > 1: @@ -193,11 +191,11 @@ def _infer_line_data2( dims_T = [] if np.issubdtype(darray.dtype, np.floating): for v in ["z", "x"]: - dim = dims_plot.get(v, None) + dim = coords_to_plot.get(v, None) if (dim is not None) and (dim in darray.dims): darray_nan = np.nan * darray.isel({dim: -1}) darray = concat([darray, darray_nan], dim=dim) - dims_T.append(dims_plot[v]) + dims_T.append(coords_to_plot[v]) # Lines should never connect to the same coordinate when stacked, # transpose to avoid this as much as possible: @@ -207,11 +205,13 @@ def _infer_line_data2( darray = darray.stack(_stacked_dim=darray.dims) # Broadcast together all the chosen variables: - out = dict(y=darray) - out.update({k: darray[v] for k, v in dims_plot.items() if v is not None}) - out = dict(zip(out.keys(), broadcast(*(out.values())))) + plts = dict(y=darray) + plts.update( + {k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None} + ) + plts = dict(zip(plts.keys(), broadcast(*(plts.values())))) - return out + return plts # return type is 
Any due to the many different possibilities @@ -938,15 +938,20 @@ def newplotfunc( _is_facetgrid = kwargs.pop("_is_facetgrid", False) if plotfunc.__name__ == "scatter": - size_ = markersize + size_ = kwargs.pop("_size", markersize) size_r = _MARKERSIZE_RANGE else: - size_ = linewidth + size_ = kwargs.pop("_size", linewidth) size_r = _LINEWIDTH_RANGE # Get data to plot: - dims_plot = dict(x=x, z=z, hue=hue, size=size_) - plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__) + coords_to_plot: MutableMapping[str, Hashable | None] = dict( + x=x, z=z, hue=hue, size=size_ + ) + if not _is_facetgrid: + # Guess what coords to use if some of the values in coords_to_plot are None: + coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs) + plts = _prepare_plot1d_data(darray, coords_to_plot, plotfunc.__name__) xplt = plts.pop("x", None) yplt = plts.pop("y", None) zplt = plts.pop("z", None) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 4d90c160400..93a328836d0 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -3,7 +3,7 @@ import functools import itertools import warnings -from collections.abc import Hashable, Iterable +from collections.abc import Hashable, Iterable, MutableMapping from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast import numpy as np @@ -16,6 +16,7 @@ _add_legend, _determine_guide, _get_nice_quiver_magnitude, + _guess_coords_to_plot, _infer_xy_labels, _Normalize, _parse_size, @@ -383,6 +384,11 @@ def map_plot1d( func: Callable, x: Hashable | None, y: Hashable | None, + *, + z: Hashable | None = None, + hue: Hashable | None = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, **kwargs: Any, ) -> T_FacetGrid: """ @@ -415,13 +421,25 @@ def map_plot1d( if kwargs.get("cbar_ax", None) is not None: raise ValueError("cbar_ax not supported by FacetGrid.") + if func.__name__ == "scatter": + size_ = kwargs.pop("_size", markersize) + size_r = 
_MARKERSIZE_RANGE + else: + size_ = kwargs.pop("_size", linewidth) + size_r = _LINEWIDTH_RANGE + + # Guess what coords to use if some of the values in coords_to_plot are None: + coords_to_plot: MutableMapping[str, Hashable | None] = dict( + x=x, z=z, hue=hue, size=size_ + ) + coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs) + # Handle hues: - hue = kwargs.get("hue", None) - hueplt = self.data[hue] if hue else self.data + hue = coords_to_plot["hue"] + hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ? hueplt_norm = _Normalize(hueplt) self._hue_var = hueplt cbar_kwargs = kwargs.pop("cbar_kwargs", {}) - if hueplt_norm.data is not None: if not hueplt_norm.data_is_numeric: # TODO: Ticks seems a little too hardcoded, since it will always @@ -441,16 +459,11 @@ def map_plot1d( cmap_params = {} # Handle sizes: - _size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE - for _size in ("markersize", "linewidth"): - size = kwargs.get(_size, None) - - sizeplt = self.data[size] if size else None - sizeplt_norm = _Normalize(data=sizeplt, width=_size_r) - if size: - self.data[size] = sizeplt_norm.values - kwargs.update(**{_size: size}) - break + size_ = coords_to_plot["size"] + sizeplt = self.data.coords[size_] if size_ else None + sizeplt_norm = _Normalize(data=sizeplt, width=size_r) + if sizeplt_norm.data is not None: + self.data[size_] = sizeplt_norm.values # Add kwargs that are sent to the plotting function, # order is important ??? 
func_kwargs = { @@ -504,6 +517,8 @@ def map_plot1d( x=x, y=y, ax=ax, + hue=hue, + _size=size_, **func_kwargs, _is_facetgrid=True, ) diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 7afcaed142d..b5d5a122c7a 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -3,7 +3,7 @@ import itertools import textwrap import warnings -from collections.abc import Hashable, Iterable, Mapping, Sequence +from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence from datetime import datetime from inspect import getfullargspec from typing import TYPE_CHECKING, Any, Callable, overload @@ -1735,3 +1735,92 @@ def _add_legend( _adjust_legend_subtitles(legend) return legend + + +def _guess_coords_to_plot( + darray: DataArray, + coords_to_plot: MutableMapping[str, Hashable | None], + kwargs: dict, + default_guess: tuple[str, ...] = ("x",), + # TODO: Can this be normalized, plt.cbook.normalize_kwargs? + ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),), +) -> MutableMapping[str, Hashable]: + """ + Guess what coords to plot if some of the values in coords_to_plot are None which + happens when the user has not defined all available ways of visualizing + the data. + + Parameters + ---------- + darray : DataArray + The DataArray to check for available coords. + coords_to_plot : MutableMapping[str, Hashable] + Coords defined by the user to plot. + kwargs : dict + Extra kwargs that will be sent to matplotlib. + default_guess : Iterable[str], optional + Default values and order to retrieve coords if values in coords_to_plot are + missing, default: ("x",). + ignore_guess_kwargs : tuple[tuple[str, ...], ...] + Matplotlib arguments to ignore. + + Examples + -------- + >>> ds = xr.tutorial.scatter_example_dataset(seed=42) + >>> # Only guess x by default: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={}, + ... 
) + {'x': 'x', 'z': None, 'hue': None, 'size': None} + + >>> # Guess all plot dims with other default values: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'x', 'z': None, 'hue': 'y', 'size': 'z'} + + >>> # Don't guess ´size´, since the matplotlib kwarg ´s´ has been defined: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": None}, + ... kwargs={"s": 5}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'x', 'z': None, 'hue': 'y', 'size': None} + + >>> # Prioritize ´size´ over ´s´: + >>> xr.plot.utils._guess_coords_to_plot( + ... ds.A, + ... coords_to_plot={"x": None, "z": None, "hue": None, "size": "x"}, + ... kwargs={"s": 5}, + ... default_guess=("x", "hue", "size"), + ... ignore_guess_kwargs=((), ("c", "color"), ("s",)), + ... ) + {'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'} + """ + coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None} + available_coords = tuple( + k for k in darray.coords.keys() if k not in coords_to_plot_exist.values() + ) + + # If coords_to_plot[k] isn't defined then fill with one of the available coords, unless + # one of the related mpl kwargs has been used. This should have similar behaviour as + # * plt.plot(x, y) -> Multiple lines with different colors if y is 2d. + # * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d. 
+ for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs): + if coords_to_plot.get(k, None) is None and all( + kwargs.get(ign_kw, None) is None for ign_kw in ign_kws + ): + coords_to_plot[k] = dim + + for k, dim in coords_to_plot.items(): + _assert_valid_xy(darray, dim, k) + + return coords_to_plot diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 8c2b31e929e..40204691e85 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2717,7 +2717,7 @@ def test_scatter( def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] - pc = ds2.plot.scatter(x="A", y="B", hue="hue") + pc = ds2.plot.scatter(x="A", y="B", markersize="hue") # should make a discrete legend assert pc.axes.legend_ is not None @@ -2725,15 +2725,9 @@ def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] - pc = ds2.plot.scatter(x="A", y="B", hue="hue") + pc = ds2.plot.scatter(x="A", y="B", markersize="hue") actual = [t.get_text() for t in pc.axes.get_legend().texts] - expected = [ - "col [colunits]", - "$\\mathdefault{0}$", - "$\\mathdefault{1}$", - "$\\mathdefault{2}$", - "$\\mathdefault{3}$", - ] + expected = ["hue", "a", "b"] assert actual == expected def test_legend_labels_facetgrid(self) -> None: