Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plots get labels from pint arrays #5561

Merged
merged 38 commits into from
Jul 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ec3040e
test labels come from pint units
TomNicholas Jul 2, 2021
6bb4298
values demotes pint arrays before returning
TomNicholas Jul 2, 2021
04d78c5
plot labels look for pint units first
TomNicholas Jul 2, 2021
6fee339
pre-commit
TomNicholas Jul 2, 2021
17c5755
added to_numpy() and as_numpy() methods
TomNicholas Jul 2, 2021
48ba107
remove special-casing of cupy arrays in .values in favour of using .t…
TomNicholas Jul 2, 2021
531385b
merged to_numpy() method in
TomNicholas Jul 2, 2021
9cb2f9b
.values -> .to_numpy()
TomNicholas Jul 2, 2021
ae6e931
lint
max-sixty Jul 2, 2021
dc24d3f
Fix mypy (I think?)
max-sixty Jul 2, 2021
6ce6b05
Merge branch 'main' of https://github.com/pydata/xarray into to_numpy
TomNicholas Jul 3, 2021
04d7b02
Merge branch 'to_numpy' of https://github.com/TomNicholas/xarray into…
TomNicholas Jul 3, 2021
ee34649
added Dataset.as_numpy()
TomNicholas Jul 3, 2021
552b322
improved docstrings
TomNicholas Jul 3, 2021
1215e69
add what's new
TomNicholas Jul 3, 2021
af8a1ee
add to API docs
TomNicholas Jul 3, 2021
e095bf0
linting
TomNicholas Jul 3, 2021
eb7d84d
fix failures by only importing pint when needed
TomNicholas Jul 7, 2021
4d43f17
merge fix for pint import errors
TomNicholas Jul 7, 2021
74c05e3
refactor pycompat into class
TomNicholas Jul 7, 2021
7e5e928
Merge pyompat refactor from branch 'to_numpy' into unit-free-values
TomNicholas Jul 7, 2021
3f85e21
pycompat import changes applied to plotting code
TomNicholas Jul 7, 2021
e397168
what's new
TomNicholas Jul 7, 2021
45245d0
compute instead of load
TomNicholas Jul 8, 2021
27fc4e5
added tests
TomNicholas Jul 8, 2021
3e8cb24
fixed sparse test
TomNicholas Jul 8, 2021
f9d6370
tests and fixes for ds.as_numpy()
TomNicholas Jul 9, 2021
50fdf4c
fix sparse tests
TomNicholas Jul 9, 2021
1c94a97
fix linting
TomNicholas Jul 9, 2021
2d07c0f
tests for Variable
TomNicholas Jul 9, 2021
9673cea
test IndexVariable too
TomNicholas Jul 9, 2021
0d624cc
use numpy.asarray to avoid a copy
TomNicholas Jul 12, 2021
2f1ff46
also convert coords
TomNicholas Jul 14, 2021
afd35e2
Merge branch 'main' into to_numpy
TomNicholas Jul 15, 2021
6d33b35
Force tests again after #5600
TomNicholas Jul 16, 2021
eae95f5
Merge branch 'main' into to_numpy
TomNicholas Jul 16, 2021
36f3bd9
Merge branch 'to_numpy' into unit-free-values
TomNicholas Jul 16, 2021
4c53790
merged main
TomNicholas Jul 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ New Features
By `Elle Smith <https://github.com/ellesmith88>`_.
- Added :py:meth:`DataArray.to_numpy`, :py:meth:`DataArray.as_numpy`, and :py:meth:`Dataset.as_numpy`. (:pull:`5568`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.
- Units in plot labels are now automatically inferred from wrapped :py:meth:`pint.Quantity` arrays. (:pull:`5561`).
By `Tom Nicholas <https://github.com/TomNicholas>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2784,7 +2784,7 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray:
result : MaskedArray
Masked where invalid values (nan or inf) occur.
"""
values = self.values # only compute lazy arrays once
values = self.to_numpy() # only compute lazy arrays once
isnull = pd.isnull(values)
return np.ma.MaskedArray(data=values, mask=isnull, copy=copy)

Expand Down
1 change: 1 addition & 0 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,7 @@ def to_numpy(self) -> np.ndarray:
"""Coerces wrapped data to numpy and returns a numpy.ndarray"""
# TODO an entrypoint so array libraries can choose coercion method?
data = self.data

# TODO first attempt to call .to_numpy() once some libraries implement it
if isinstance(data, dask_array_type):
data = data.compute()
Expand Down
10 changes: 5 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ def line(

# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot(
xplt.values, yplt.values, kwargs
xplt.to_numpy(), yplt.to_numpy(), kwargs
)
xlabel = label_from_attrs(xplt, extra=x_suffix)
ylabel = label_from_attrs(yplt, extra=y_suffix)
Expand All @@ -449,7 +449,7 @@ def line(
ax.set_title(darray._title_for_slice())

if darray.ndim == 2 and add_legend:
ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label)
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)

# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
Expand Down Expand Up @@ -551,7 +551,7 @@ def hist(
"""
ax = get_axis(figsize, size, aspect, ax)

no_nan = np.ravel(darray.values)
no_nan = np.ravel(darray.to_numpy())
no_nan = no_nan[pd.notnull(no_nan)]

primitive = ax.hist(no_nan, **kwargs)
Expand Down Expand Up @@ -1153,8 +1153,8 @@ def newplotfunc(
dims = (yval.dims[0], xval.dims[0])

# better to pass the ndarrays directly to plotting functions
xval = xval.values
yval = yval.values
xval = xval.to_numpy()
yval = yval.to_numpy()

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
Expand Down
21 changes: 15 additions & 6 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd

from ..core.options import OPTIONS
from ..core.pycompat import DuckArrayModule
from ..core.utils import is_scalar

try:
Expand Down Expand Up @@ -474,12 +475,20 @@ def label_from_attrs(da, extra=""):
else:
name = ""

if da.attrs.get("units"):
units = " [{}]".format(da.attrs["units"])
elif da.attrs.get("unit"):
units = " [{}]".format(da.attrs["unit"])
def _get_units_from_attrs(da):
if da.attrs.get("units"):
units = " [{}]".format(da.attrs["units"])
elif da.attrs.get("unit"):
units = " [{}]".format(da.attrs["unit"])
else:
units = ""
return units

pint_array_type = DuckArrayModule("pint").type
if isinstance(da.data, pint_array_type):
units = " [{}]".format(str(da.data.units))
else:
units = ""
units = _get_units_from_attrs(da)

return "\n".join(textwrap.wrap(name + extra + units, 30))

Expand Down Expand Up @@ -896,7 +905,7 @@ def _get_nice_quiver_magnitude(u, v):
import matplotlib as mpl

ticker = mpl.ticker.MaxNLocator(3)
mean = np.mean(np.hypot(u.values, v.values))
mean = np.mean(np.hypot(u.to_numpy(), v.to_numpy()))
magnitude = ticker.tick_values(0, mean)[-2]
return magnitude

Expand Down
40 changes: 39 additions & 1 deletion xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,22 @@
import pandas as pd
import pytest

try:
import matplotlib.pyplot as plt
except ImportError:
pass

import xarray as xr
from xarray.core import dtypes, duck_array_ops

from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
from . import (
assert_allclose,
assert_duckarray_allclose,
assert_equal,
assert_identical,
requires_matplotlib,
)
from .test_plot import PlotTestCase
from .test_variable import _PAD_XR_NP_ARGS

pint = pytest.importorskip("pint")
Expand Down Expand Up @@ -5564,3 +5576,29 @@ def test_merge(self, variant, unit, error, dtype):

assert_units_equal(expected, actual)
assert_equal(expected, actual)


@requires_matplotlib
class TestPlots(PlotTestCase):
def test_units_in_line_plot_labels(self):
arr = np.linspace(1, 10, 3) * unit_registry.Pa
# TODO make coord a Quantity once unit-aware indexes supported
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't it be possible to specify a non-dimensional coordinate with x="x_coord"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The point was that I want this test to eventually test labels on both the x and y axes, but at the moment pint is only involved with the y axis, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's right. I meant to implement this by converting the dimension coordinate to a non-dimension coordinate:

x_coord = xr.DataArray(
    np.linspace(1, 3, 3) * unit_registry.m, dims="x"
)
da = xr.DataArray(data=arr, dims="x", coords={"x_coord": x_coord}, name="pressure")
da.plot.line(x="x_coord")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I try that I still get a UnitStrippedWarning in .to_index_variable()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems like a bug. open an issue for it so we don't forget?

x_coord = xr.DataArray(
np.linspace(1, 3, 3), dims="x", attrs={"units": "meters"}
)
da = xr.DataArray(data=arr, dims="x", coords={"x": x_coord}, name="pressure")

da.plot.line()

ax = plt.gca()
assert ax.get_ylabel() == "pressure [pascal]"
assert ax.get_xlabel() == "x [meters]"

def test_units_in_2d_plot_labels(self):
arr = np.ones((2, 3)) * unit_registry.Pa
da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure")

fig, (ax, cax) = plt.subplots(1, 2)
ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True)

assert cax.get_ylabel() == "pressure [pascal]"