Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use .to_numpy() for quantified facetgrids #5886

Merged
merged 10 commits into from
Oct 28, 2021
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ Bug fixes
By `Jimmy Westling <https://github.com/illviljan>`_.
- Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`).
By `Maxime Liquet <https://github.com/maximlt>`_.
- Faceted plots will no longer raise a `pint.UnitStrippedWarning` when a `pint.Quantity` array is plotted,
and will correctly display the units of the data in the colorbar (if there is one) (:pull:`5886`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- With backends, check for path-like objects rather than ``pathlib.Path``
type, use ``os.fspath`` (:pull:`5879`).
By `Mike Taves <https://github.com/mwtoews>`_.
Expand Down
14 changes: 7 additions & 7 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ def __init__(
)

# Set up the lists of names for the row and column facet variables
col_names = list(data[col].values) if col else []
row_names = list(data[row].values) if row else []
col_names = list(data[col].to_numpy()) if col else []
row_names = list(data[row].to_numpy()) if row else []

if single_group:
full = [{single_group: x} for x in data[single_group].values]
full = [{single_group: x} for x in data[single_group].to_numpy()]
empty = [None for x in range(nrow * ncol - len(full))]
name_dicts = full + empty
else:
Expand Down Expand Up @@ -251,7 +251,7 @@ def map_dataarray(self, func, x, y, **kwargs):
raise ValueError("cbar_ax not supported by FacetGrid.")

cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data.values, **kwargs
func, self.data.to_numpy(), **kwargs
)

self._cmap_extend = cmap_params.get("extend")
Expand Down Expand Up @@ -347,7 +347,7 @@ def map_dataset(

if hue and meta_data["hue_style"] == "continuous":
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
func, self.data[hue].values, **kwargs
func, self.data[hue].to_numpy(), **kwargs
)
kwargs["meta_data"]["cmap_params"] = cmap_params
kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs
Expand Down Expand Up @@ -423,7 +423,7 @@ def _adjust_fig_for_guide(self, guide):
def add_legend(self, **kwargs):
self.figlegend = self.fig.legend(
handles=self._mappables[-1],
labels=list(self._hue_var.values),
labels=list(self._hue_var.to_numpy()),
title=self._hue_label,
loc="center right",
**kwargs,
Expand Down Expand Up @@ -619,7 +619,7 @@ def map(self, func, *args, **kwargs):
if namedict is not None:
data = self.data.loc[namedict]
plt.sca(ax)
innerargs = [data[a].values for a in args]
innerargs = [data[a].to_numpy() for a in args]
maybe_mappable = func(*innerargs, **kwargs)
# TODO: better way to verify that an artist is mappable?
# https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
Expand Down
10 changes: 5 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def newplotfunc(
# Matplotlib does not support normalising RGB data, so do it here.
# See eg. https://github.com/matplotlib/matplotlib/pull/10220
if robust or vmax is not None or vmin is not None:
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust)
vmin, vmax, robust = None, None, False

if subplot_kws is None:
Expand Down Expand Up @@ -1146,10 +1146,6 @@ def newplotfunc(
else:
dims = (yval.dims[0], xval.dims[0])

# better to pass the ndarrays directly to plotting functions
xval = xval.to_numpy()
yval = yval.to_numpy()

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if imshow_rgb:
Expand All @@ -1162,6 +1158,10 @@ def newplotfunc(
if dims != darray.dims:
darray = darray.transpose(*dims, transpose_coords=True)

# better to pass the ndarrays directly to plotting functions
xval = xval.to_numpy()
yval = yval.to_numpy()

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)

Expand Down
26 changes: 25 additions & 1 deletion xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -5614,11 +5614,35 @@ def test_units_in_line_plot_labels(self):
assert ax.get_ylabel() == "pressure [pascal]"
assert ax.get_xlabel() == "x [meters]"

def test_units_in_2d_plot_labels(self):
def test_units_in_2d_plot_colorbar_label(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)

assert cax.get_ylabel() == "pressure [pascal]"

def test_units_facetgrid_plot_labels(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
fgrid = da.plot.line(x="x", col="y")

assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]"

def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self):
arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure")

da.plot.imshow(x="x", y="y", col="w") # no colorbar to check labels of

def test_units_facetgrid_2d_contourf_plot_colorbar_labels(self):
arr = np.ones((2, 3, 4)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y", "z"], name="pressure")

fig, (ax1, ax2, ax3, cax) = plt.subplots(1, 4)
fgrid = da.plot.contourf(x="x", y="y", col="z")

assert fgrid.cbar.ax.get_ylabel() == "pressure [pascal]"