Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require to explicitly defining optional dimensions such as hue and markersize #7277

Merged
merged 29 commits into from
Feb 11, 2023
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
2425bdc
Prioritize mpl kwargs when hue/size isn't defined.
Illviljan Nov 10, 2022
09fedc6
Update dataarray_plot.py
Illviljan Nov 10, 2022
277ec39
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Nov 11, 2022
a2740e0
rename vars for clarity
Illviljan Nov 12, 2022
4bce4a9
Handle int coords
Illviljan Nov 13, 2022
b885289
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2022
192fd47
Update dataarray_plot.py
Illviljan Nov 13, 2022
42b7232
Merge branch 'dont_guess_for_some_kwargs' of https://github.com/Illvi…
Illviljan Nov 13, 2022
b534f9a
Move funcs to utils and use in facetgrid, fix int coords in facetgrid
Illviljan Nov 13, 2022
f6d8a67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2022
a53c17f
Update dataarray_plot.py
Illviljan Nov 13, 2022
1bd971f
Merge branch 'dont_guess_for_some_kwargs' of https://github.com/Illvi…
Illviljan Nov 13, 2022
2ee9e47
Update utils.py
Illviljan Nov 13, 2022
cf7a016
Update utils.py
Illviljan Nov 13, 2022
c4a0c48
Update facetgrid.py
Illviljan Nov 13, 2022
f4a26da
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Nov 20, 2022
65e2367
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Nov 22, 2022
a5e6842
typing fixes
Illviljan Nov 23, 2022
58944eb
Only guess x-axis.
Illviljan Nov 23, 2022
04694f9
fix tests
Illviljan Nov 23, 2022
fa49f55
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Dec 18, 2022
151b9cf
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Jan 20, 2023
681ec75
rename function to a better name.
Illviljan Jan 26, 2023
1b3ba5e
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Jan 26, 2023
db0f64e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
0c140b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2023
d72827a
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Feb 9, 2023
762950d
Merge branch 'main' into dont_guess_for_some_kwargs
Illviljan Feb 9, 2023
bcdd818
Update whats-new.rst
Illviljan Feb 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 51 additions & 46 deletions xarray/plot/dataarray_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
_assert_valid_xy,
_determine_guide,
_ensure_plottable,
_guess_coords_to_plot,
_infer_interval_breaks,
_infer_xy_labels,
_Normalize,
Expand Down Expand Up @@ -148,48 +149,45 @@ def _infer_line_data(
return xplt, yplt, hueplt, huelabel


def _infer_plot_dims(
darray: DataArray,
dims_plot: MutableMapping[str, Hashable],
default_guess: Iterable[str] = ("x", "hue", "size"),
) -> MutableMapping[str, Hashable]:
def _infer_line_data2(
Illviljan marked this conversation as resolved.
Show resolved Hide resolved
darray: T_DataArray,
coords_to_plot: MutableMapping[str, Hashable],
plotfunc_name: str | None = None,
_is_facetgrid: bool = False,
) -> dict[str, T_DataArray]:
"""
Guess what dims to plot if some of the values in dims_plot are None which
happens when the user has not defined all available ways of visualizing
the data.
Infer data to plot.

Parameters
----------
darray : DataArray
The DataArray to check.
dims_plot : T_DimsPlot
Dims defined by the user to plot.
default_guess : Iterable[str], optional
Default values and order to retrieve dims if values in dims_plot is
missing, default: ("x", "hue", "size").
"""
dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None}
dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values())

# If dims_plot[k] isn't defined then fill with one of the available dims:
for k, v in zip(default_guess, dims_avail):
if dims_plot.get(k, None) is None:
dims_plot[k] = v

for k, v in dims_plot.items():
_assert_valid_xy(darray, v, k)

return dims_plot

darray : T_DataArray
Base DataArray.
coords_to_plot : MutableMapping[str, Hashable]
Coords that will be plotted.
plotfunc_name : str | None
Name of the plotting function that will be used.

def _infer_line_data2(
darray: T_DataArray,
dims_plot: MutableMapping[str, Hashable],
plotfunc_name: None | str = None,
) -> dict[str, T_DataArray]:
# Guess what dims to use if some of the values in plot_dims are None:
dims_plot = _infer_plot_dims(darray, dims_plot)
Returns
-------
plts : dict[str, T_DataArray]
Dict of DataArrays that will be sent to matplotlib.

Examples
--------
>>> # Make sure int coords are plotted:
>>> a = xr.DataArray(
... data=[1, 2],
... coords={1: ("x", [0, 1], {"units": "s"})},
... dims=("x",),
... name="a",
... )
>>> plts = xr.plot.dataarray_plot._infer_line_data2(
... a, coords_to_plot={"x": 1, "z": None, "hue": None, "size": None}
... )
>>> # Check which coords to plot:
>>> print({k: v.name for k, v in plts.items()})
{'y': 'a', 'x': 1}
"""
# If there are more than 1 dimension in the array than stack all the
# dimensions so the plotter can plot anything:
if darray.ndim > 1:
Expand All @@ -199,11 +197,11 @@ def _infer_line_data2(
dims_T = []
if np.issubdtype(darray.dtype, np.floating):
for v in ["z", "x"]:
dim = dims_plot.get(v, None)
dim = coords_to_plot.get(v, None)
if (dim is not None) and (dim in darray.dims):
darray_nan = np.nan * darray.isel({dim: -1})
darray = concat([darray, darray_nan], dim=dim)
dims_T.append(dims_plot[v])
dims_T.append(coords_to_plot[v])

# Lines should never connect to the same coordinate when stacked,
# transpose to avoid this as much as possible:
Expand All @@ -213,11 +211,13 @@ def _infer_line_data2(
darray = darray.stack(_stacked_dim=darray.dims)

# Broadcast together all the chosen variables:
out = dict(y=darray)
out.update({k: darray[v] for k, v in dims_plot.items() if v is not None})
out = dict(zip(out.keys(), broadcast(*(out.values()))))
plts = dict(y=darray)
plts.update(
{k: darray.coords[v] for k, v in coords_to_plot.items() if v is not None}
)
plts = dict(zip(plts.keys(), broadcast(*(plts.values()))))

return out
return plts


# return type is Any due to the many different possibilities
Expand Down Expand Up @@ -941,15 +941,20 @@ def newplotfunc(
_is_facetgrid = kwargs.pop("_is_facetgrid", False)

if plotfunc.__name__ == "scatter":
size_ = markersize
size_ = kwargs.pop("_size", markersize)
size_r = _MARKERSIZE_RANGE
else:
size_ = linewidth
size_ = kwargs.pop("_size", linewidth)
size_r = _LINEWIDTH_RANGE

# Get data to plot:
dims_plot = dict(x=x, z=z, hue=hue, size=size_)
plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__)
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
x=x, z=z, hue=hue, size=size_
)
if not _is_facetgrid:
# Guess what coords to use if some of the values in coords_to_plot are None:
coords_to_plot = _guess_coords_to_plot(darray, coords_to_plot, kwargs)
plts = _infer_line_data2(darray, coords_to_plot, plotfunc.__name__)
xplt = plts.pop("x", None)
yplt = plts.pop("y", None)
zplt = plts.pop("z", None)
Expand Down
42 changes: 29 additions & 13 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Hashable,
Iterable,
Literal,
MutableMapping,
TypeVar,
cast,
)
Expand All @@ -25,6 +26,7 @@
_add_legend,
_determine_guide,
_get_nice_quiver_magnitude,
_guess_coords_to_plot,
_infer_xy_labels,
_Normalize,
_parse_size,
Expand Down Expand Up @@ -392,6 +394,11 @@ def map_plot1d(
func: Callable,
x: Hashable | None,
y: Hashable | None,
*,
z: Hashable | None = None,
hue: Hashable | None = None,
markersize: Hashable | None = None,
linewidth: Hashable | None = None,
**kwargs: Any,
) -> T_FacetGrid:
"""
Expand Down Expand Up @@ -424,13 +431,25 @@ def map_plot1d(
if kwargs.get("cbar_ax", None) is not None:
raise ValueError("cbar_ax not supported by FacetGrid.")

if func.__name__ == "scatter":
size_ = kwargs.pop("_size", markersize)
size_r = _MARKERSIZE_RANGE
else:
size_ = kwargs.pop("_size", linewidth)
size_r = _LINEWIDTH_RANGE

# Guess what coords to use if some of the values in coords_to_plot are None:
coords_to_plot: MutableMapping[str, Hashable | None] = dict(
x=x, z=z, hue=hue, size=size_
)
coords_to_plot = _guess_coords_to_plot(self.data, coords_to_plot, kwargs)

# Handle hues:
hue = kwargs.get("hue", None)
hueplt = self.data[hue] if hue else self.data
hue = coords_to_plot["hue"]
hueplt = self.data.coords[hue] if hue else None # TODO: _infer_line_data2 ?
hueplt_norm = _Normalize(hueplt)
self._hue_var = hueplt
cbar_kwargs = kwargs.pop("cbar_kwargs", {})

if hueplt_norm.data is not None:
if not hueplt_norm.data_is_numeric:
# TODO: Ticks seems a little too hardcoded, since it will always
Expand All @@ -450,16 +469,11 @@ def map_plot1d(
cmap_params = {}

# Handle sizes:
_size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE
for _size in ("markersize", "linewidth"):
size = kwargs.get(_size, None)

sizeplt = self.data[size] if size else None
sizeplt_norm = _Normalize(data=sizeplt, width=_size_r)
if size:
self.data[size] = sizeplt_norm.values
kwargs.update(**{_size: size})
break
size_ = coords_to_plot["size"]
sizeplt = self.data.coords[size_] if size_ else None
sizeplt_norm = _Normalize(data=sizeplt, width=size_r)
if sizeplt_norm.data is not None:
self.data[size_] = sizeplt_norm.values

# Add kwargs that are sent to the plotting function, # order is important ???
func_kwargs = {
Expand Down Expand Up @@ -513,6 +527,8 @@ def map_plot1d(
x=x,
y=y,
ax=ax,
hue=hue,
_size=size_,
**func_kwargs,
_is_facetgrid=True,
)
Expand Down
90 changes: 90 additions & 0 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Hashable,
Iterable,
Mapping,
MutableMapping,
Sequence,
overload,
)
Expand Down Expand Up @@ -1749,3 +1750,92 @@ def _add_legend(
_adjust_legend_subtitles(legend)

return legend


def _guess_coords_to_plot(
darray: DataArray,
coords_to_plot: MutableMapping[str, Hashable | None],
kwargs: dict,
default_guess: tuple[str, ...] = ("x",),
# TODO: Can this be normalized, plt.cbook.normalize_kwargs?
ignore_guess_kwargs: tuple[tuple[str, ...], ...] = ((),),
) -> MutableMapping[str, Hashable]:
"""
Guess what coords to plot if some of the values in coords_to_plot are None which
happens when the user has not defined all available ways of visualizing
the data.

Parameters
----------
darray : DataArray
The DataArray to check for available coords.
coords_to_plot : MutableMapping[str, Hashable]
Coords defined by the user to plot.
kwargs : dict
Extra kwargs that will be sent to matplotlib.
default_guess : Iterable[str], optional
Default values and order to retrieve dims if values in dims_plot is
missing, default: ("x", "hue", "size").
ignore_guess_kwargs : tuple[tuple[str, ...], ...]
Matplotlib arguments to ignore.

Examples
--------
>>> ds = xr.tutorial.scatter_example_dataset(seed=42)
>>> # Only guess x by default:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={},
... )
{'x': 'x', 'z': None, 'hue': None, 'size': None}

>>> # Guess all plot dims with other default values:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'x', 'z': None, 'hue': 'y', 'size': 'z'}

>>> # Don't guess ´size´, since the matplotlib kwarg ´s´ has been defined:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": None},
... kwargs={"s": 5},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'x', 'z': None, 'hue': 'y', 'size': None}

>>> # Prioritize ´size´ over ´s´:
>>> xr.plot.utils._guess_coords_to_plot(
... ds.A,
... coords_to_plot={"x": None, "z": None, "hue": None, "size": "x"},
... kwargs={"s": 5},
... default_guess=("x", "hue", "size"),
... ignore_guess_kwargs=((), ("c", "color"), ("s",)),
... )
{'x': 'y', 'z': None, 'hue': 'z', 'size': 'x'}
"""
coords_to_plot_exist = {k: v for k, v in coords_to_plot.items() if v is not None}
available_coords = tuple(
k for k in darray.coords.keys() if k not in coords_to_plot_exist.values()
)

# If dims_plot[k] isn't defined then fill with one of the available dims, unless
# one of related mpl kwargs has been used. This should have similiar behaviour as
# * plt.plot(x, y) -> Multple lines with different colors if y is 2d.
# * plt.plot(x, y, color="red") -> Multiple red lines if y is 2d.
for k, dim, ign_kws in zip(default_guess, available_coords, ignore_guess_kwargs):
if coords_to_plot.get(k, None) is None and all(
kwargs.get(ign_kw, None) is None for ign_kw in ign_kws
):
coords_to_plot[k] = dim

for k, dim in coords_to_plot.items():
_assert_valid_xy(darray, dim, k)

return coords_to_plot
12 changes: 3 additions & 9 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2729,23 +2729,17 @@ def test_scatter(
def test_non_numeric_legend(self) -> None:
ds2 = self.ds.copy()
ds2["hue"] = ["a", "b", "c", "d"]
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
# should make a discrete legend
assert pc.axes.legend_ is not None

def test_legend_labels(self) -> None:
# regression test for #4126: incorrect legend labels
ds2 = self.ds.copy()
ds2["hue"] = ["a", "a", "b", "b"]
pc = ds2.plot.scatter(x="A", y="B", hue="hue")
pc = ds2.plot.scatter(x="A", y="B", markersize="hue")
actual = [t.get_text() for t in pc.axes.get_legend().texts]
expected = [
"col [colunits]",
"$\\mathdefault{0}$",
"$\\mathdefault{1}$",
"$\\mathdefault{2}$",
"$\\mathdefault{3}$",
]
expected = ["hue", "a", "b"]
assert actual == expected

def test_legend_labels_facetgrid(self) -> None:
Expand Down