Skip to content

Commit

Permalink
Require to explicitly defining optional dimensions such as hue and ma…
Browse files Browse the repository at this point in the history
…rkersize (#7277)

* Prioritize mpl kwargs when hue/size isn't defined.

* Update dataarray_plot.py

* rename vars for clarity

* Handle int coords

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataarray_plot.py

* Move funcs to utils and use in facetgrid, fix int coords in facetgrid

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dataarray_plot.py

* Update utils.py

* Update utils.py

* Update facetgrid.py

* typing fixes

* Only guess x-axis.

* fix tests

* rename function to a better name.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update whats-new.rst

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Feb 11, 2023
1 parent 7683442 commit 9ff932a
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 74 deletions.
10 changes: 6 additions & 4 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ Deprecations

Bug fixes
~~~~~~~~~

- Require to explicitly defining optional dimensions such as hue
and markersize for scatter plots. (:issue:`7314`, :pull:`7277`).
By `Jimmy Westling <https://github.com/illviljan>`_.
- Fix matplotlib raising a UserWarning when plotting a scatter plot
with an unfilled marker (:issue:`7313`, :pull:`7318`).
By `Jimmy Westling <https://github.com/illviljan>`_.

Documentation
~~~~~~~~~~~~~
Expand Down Expand Up @@ -194,9 +199,6 @@ Bug fixes
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Fix static typing of :py:meth:`xr.polyval` (:issue:`7312`, :pull:`7315`).
By `Michael Niklas <https://github.com/headtr1ck>`_.
- Fix matplotlib raising a UserWarning when plotting a scatter plot
with an unfilled marker (:issue:`7313`, :pull:`7318`).
By `Jimmy Westling <https://github.com/illviljan>`_.
- Fix multiple reads on fsspec S3 files by resetting file pointer to 0 when reading file streams (:issue:`6813`, :pull:`7304`).
By `David Hoese <https://github.com/djhoese>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
- Fix :py:meth:`Dataset.assign_coords` resetting all dimension coordinates to default (pandas) index (:issue:`7346`, :pull:`7347`).
Expand Down
97 changes: 51 additions & 46 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_assert_valid_xy,
_determine_guide,
_ensure_plottable,
_guess_coords_to_plot,
_infer_interval_breaks,
_infer_xy_labels,
_Normalize,
Expand Down Expand Up @@ -142,48 +143,45 @@ def _infer_line_data(
return xplt, yplt, hueplt, huelabel


def _infer_plot_dims(
darray: DataArray,
dims_plot: MutableMapping[str, Hashable],
default_guess: Iterable[str] = ("x", "hue", "size"),
) -> MutableMapping[str, Hashable]:
def _prepare_plot1d_data(
darray: T_DataArray,
coords_to_plot: MutableMapping[str, Hashable],
plotfunc_name: str | None = None,
_is_facetgrid: bool = False,
) -> dict[str, T_DataArray]:
"""
Guess what dims to plot if some of the values in dims_plot are None which
happens when the user has not defined all available ways of visualizing
the data.
Prepare data for usage with plt.scatter.
Parameters
----------
darray : DataArray
The DataArray to check.
dims_plot : T_DimsPlot
Dims defined by the user to plot.
default_guess : Iterable[str], optional
Default values and order to retrieve dims if values in dims_plot is
missing, default: ("x", "hue", "size").
"""
dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None}
dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values())

# If dims_plot[k] isn't defined then fill with one of the available dims:
for k, v in zip(default_guess, dims_avail):
if dims_plot.get(k, None) is None:
dims_plot[k] = v

for k, v in dims_plot.items():
_assert_valid_xy(darray, v, k)

return dims_plot

darray : T_DataArray
Base DataArray.
coords_to_plot : MutableMapping[str, Hashable]
Coords that will be plotted.
plotfunc_name : str | None
Name of the plotting function that will be used.
def _infer_line_data2(
darray: T_DataArray,
dims_plot: MutableMapping[str, Hashable],
plotfunc_name: None | str = None,
) -> dict[str, T_DataArray]:
# Guess what dims to use if some of the values in plot_dims are None:
dims_plot = _infer_plot_dims(darray, dims_plot)
Returns
-------
plts : dict[str, T_DataArray]
Dict of DataArrays that will be sent to matplotlib.
Examples
--------
>>> # Make sure int coords are plotted:
>>> a = xr.DataArray(
... data=[1, 2],
... coords={1: ("x", [0, 1], {"units": "s"})},
... dims=("x",),
... name="a",
... )
>>> plts = xr.plot.dataarray_plot._prepare_plot1d_data(
... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None}
... )
>>> # Check which coords to plot:
>>> print({k: v.name for k, v in plts.items()})
{'y': 'a', 'x': 1}
"""
# If there are more than 1 dimension in the array than stack all the
# dimensions so the plotter can plot anything:
if darray.ndim > 1:
Expand All @@ -193,11 +191,11 @@ def _infer_line_data2(
dims_T = []
if np.issubdtype(darray.dtype, np.floating):
for v in ["z", "x"]:
dim = dims_plot.get(v, None)
dim = coords_to_plot.get(v, None)
if (dim is not None) and (dim in darray.dims):
darray_nan = np.nan * darray.isel({dim: -1})
darray = concat([darray, darray_nan], dim=dim)
dims_T.append(dims_plot[v])
dims_T.append(coords_to_plot[v])

# Lines should never connect to the same coordinate when stacked,
# transpose to avoid this as much as possible:
Expand All @@ -207,11 +205,13 @@ def _infer_line_data2(
darray = darray.stack(_stacked_dim=darray.dims)

# Broadcast together all the chosen variables:
out = dict(y=darray)
out.update({k: darray[v] for k, v in dims_plot.items() if v is not None})
out = dict(zip(out.keys(), broadcast(*(out.values()))))
plts = dict(y=darray)
plts.update(
{k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None}
)
plts = dict(zip(plts.keys(), broadcast(*(plts.values()))))

return out
return plts


# return type is Any due to the many different possibilities
Expand Down Expand Up @@ -938,15 +938,20 @@ def newplotfunc(
_is_facetgrid = kwargs.pop("_is_facetgrid", False)

if plotfunc.__name__ == "scatter":
size_ = markersize
size_ = kwargs.pop("_size", markersize)
size_r = _MARKERSIZE_RANGE
else:
size_ = linewidth
size_ = kwargs.pop("_size", linewidth)
size_r = _LINEWIDTH_RANGE

# Get data to plot:
dims_plot = dict(x=x, z=z, hue=hue, size=size_)
plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__)
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
x=x, z=z, hue=hue, size=size_
)
if not _is_facetgrid:
# Guess what coords to use if some of the values in coords_to_plot are None:
coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs)
plts = _prepare_plot1d_data(darray, coords_to_plot, plotfunc.__name__)
xplt = plts.pop("x", None)
yplt = plts.pop("y", None)
zplt = plts.pop("z", None)
Expand Down
43 changes: 29 additions & 14 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import functools
import itertools
import warnings
from collections.abc import Hashable, Iterable
from collections.abc import Hashable, Iterable, MutableMapping
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar, cast

import numpy as np
Expand All @@ -16,6 +16,7 @@
_add_legend,
_determine_guide,
_get_nice_quiver_magnitude,
_guess_coords_to_plot,
_infer_xy_labels,
_Normalize,
_parse_size,
Expand Down Expand Up @@ -383,6 +384,11 @@ def map_plot1d(
func: Callable,
x: Hashable | None,
y: Hashable | None,
*,
z: Hashable | None = None,
hue: Hashable | None = None,
markersize: Hashable | None = None,
linewidth: Hashable | None = None,
**kwargs: Any,
) -> T_FacetGrid:
"""
Expand Down Expand Up @@ -415,13 +421,25 @@ def map_plot1d(
if kwargs.get("cbar_ax", None) is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

if func.__name__ == "scatter":
size_ = kwargs.pop("_size", markersize)
size_r = _MARKERSIZE_RANGE
else:
size_ = kwargs.pop("_size", linewidth)
size_r = _LINEWIDTH_RANGE

# Guess what coords to use if some of the values in coords_to_plot are None:
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
x=x, z=z, hue=hue, size=size_
)
coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs)

# Handle hues:
hue = kwargs.get("hue", None)
hueplt = self.data[hue] if hue else self.data
hue = coords_to_plot["hue"]
hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ?
hueplt_norm = _Normalize(hueplt)
self._hue_var = hueplt
cbar_kwargs = kwargs.pop("cbar_kwargs", {})

if hueplt_norm.data is not None:
if not hueplt_norm.data_is_numeric:
# TODO: Ticks seems a little too hardcoded, since it will always
Expand All @@ -441,16 +459,11 @@ def map_plot1d(
cmap_params = {}

# Handle sizes:
_size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE
for _size in ("markersize", "linewidth"):
size = kwargs.get(_size, None)

sizeplt = self.data[size] if size else None
sizeplt_norm = _Normalize(data=sizeplt, width=_size_r)
if size:
self.data[size] = sizeplt_norm.values
kwargs.update(**{_size: size})
break
size_ = coords_to_plot["size"]
sizeplt = self.data.coords[size_] if size_ else None
sizeplt_norm = _Normalize(data=sizeplt, width=size_r)
if sizeplt_norm.data is not None:
self.data[size_] = sizeplt_norm.values

# Add kwargs that are sent to the plotting function, # order is important ???
func_kwargs = {
Expand Down Expand Up @@ -504,6 +517,8 @@ def map_plot1d(
x=x,
y=y,
ax=ax,
hue=hue,
_size=size_,
**func_kwargs,
_is_facetgrid=True,
)
Expand Down
91 changes: 90 additions & 1 deletion xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import itertools
import textwrap
import warnings
from collections.abc import Hashable, Iterable, Mapping, Sequence
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from inspect import getfullargspec
from typing import TYPE_CHECKING, Any, Callable, overload
Expand Down Expand Up @@ -1735,3 +1735,92 @@ def _add_legend(
_adjust_legend_subtitles(legend)

return legend


def _guess_coords_to_plot(
darray: DataArray,
coords_to_plot: MutableMapping[str, Hashable | None],
kwargs: dict,
default_guess: tuple[str, ...] = ("x",),
# TODO: Can this be normalized, plt.cbook.normalize_kwargs?
ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),),
) -> MutableMapping[str, Hashable]:
"""
Guess what coords to plot if some of the values in coords_to_plot are None which
happens when the user has not defined all available ways of visualizing
the data.
Parameters
----------
darray : DataArray
The DataArray to check for available coords.
coords_to_plot : MutableMapping[str, Hashable]
Coords defined by the user to plot.
kwargs : dict
Extra kwargs that will be sent to matplotlib.
default_guess : Iterable[str], optional
Default values and order to retrieve dims if values in dims_plot is
missing, default: ("x", "hue", "size").
ignore_guess_kwargs : tuple[tuple[str, ...], ...]
Matplotlib arguments to ignore.
Examples
--------
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
>>> # Only guess x by default:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={},
... )
{'x': 'x', 'z': None, 'hue': None, 'size': None}
>>> # Guess all plot dims with other default values:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'x', 'z': None, 'hue': 'y', 'size': 'z'}
>>> # Don't guess ´size´, since the matplotlib kwarg ´s´ has been defined:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={"s": 5},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'x', 'z': None, 'hue': 'y', 'size': None}
>>> # Prioritize ´size´ over ´s´:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": "x"},
... kwargs={"s": 5},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'}
"""
coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None}
available_coords = tuple(
k for k in darray.coords.keys() if k not in coords_to_plot_exist.values()
)

# If dims_plot[k] isn't defined then fill with one of the available dims, unless
# one of related mpl kwargs has been used. This should have similiar behaviour as
# * plt.plot(x, y) -> Multple lines with different colors if y is 2d.
# * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d.
for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs):
if coords_to_plot.get(k, None) is None and all(
kwargs.get(ign_kw, None) is None for ign_kw in ign_kws
):
coords_to_plot[k] = dim

for k, dim in coords_to_plot.items():
_assert_valid_xy(darray, dim, k)

return coords_to_plot
12 changes: 3 additions & 9 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2717,23 +2717,17 @@ def test_scatter(
def test_non_numeric_legend(self) -> None:
ds2 = self.ds.copy()
ds2["hue"] = ["a", "b", "c", "d"]
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
# should make a discrete legend
assert pc.axes.legend_ is not None

def test_legend_labels(self) -> None:
# regression test for #4126: incorrect legend labels
ds2 = self.ds.copy()
ds2["hue"] = ["a", "a", "b", "b"]
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
actual = [t.get_text() for t in pc.axes.get_legend().texts]
expected = [
"col [colunits]",
"$\\mathdefault{0}$",
"$\\mathdefault{1}$",
"$\\mathdefault{2}$",
"$\\mathdefault{3}$",
]
expected = ["hue", "a", "b"]
assert actual == expected

def test_legend_labels_facetgrid(self) -> None:
Expand Down

0 comments on commit 9ff932a

Please sign in to comment.