Skip to content

Commit 0181aa5

Browse files
MeraXdcherian
andauthored
Allow plotting bool data (#3766)
* Allow plotting bool data Fixes #3722 . > matplotlib can plot `bool` values so we should add that to the check in `_ensure_plottable`. * Add tests + raise nicer error when asked to plot unsupported types * Add whats-new Co-authored-by: dcherian <deepak@cherian.net>
1 parent 19cb99d commit 0181aa5

File tree

4 files changed

+20
-6
lines changed

4 files changed

+20
-6
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ New Features
4646
- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`,
4747
:py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`)
4848
By `Todd Jennings <https://github.com/toddrjen>`_
49-
49+
- Allow plotting of boolean arrays. (:pull:`3766`)
50+
By `Marek Jacob <https://github.com/MeraX>`_
5051

5152
Bug fixes
5253
~~~~~~~~~

xarray/plot/plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def newplotfunc(
689689
xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
690690
yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__)
691691

692-
_ensure_plottable(xplt, yplt)
692+
_ensure_plottable(xplt, yplt, zval)
693693

694694
cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs(
695695
plotfunc, zval.data, **locals()

xarray/plot/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def _ensure_plottable(*args):
534534
Raise exception if there is anything in args that can't be plotted on an
535535
axis by matplotlib.
536536
"""
537-
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64]
537+
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64, np.bool_]
538538
other_types = [datetime]
539539
try:
540540
import cftime
@@ -549,10 +549,10 @@ def _ensure_plottable(*args):
549549
or _valid_other_type(np.array(x), other_types)
550550
):
551551
raise TypeError(
552-
"Plotting requires coordinates to be numeric "
553-
"or dates of type np.datetime64, "
552+
"Plotting requires coordinates to be numeric, boolean, "
553+
"or dates of type numpy.datetime64, "
554554
"datetime.datetime, cftime.datetime or "
555-
"pd.Interval."
555+
f"pandas.Interval. Received data of type {np.array(x).dtype} instead."
556556
)
557557
if (
558558
_valid_other_type(np.array(x), cftime_datetime)

xarray/tests/test_plot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,12 @@ def test1d(self):
139139
with raises_regex(ValueError, "None"):
140140
self.darray[:, 0, 0].plot(x="dim_1")
141141

142+
with raises_regex(TypeError, "complex128"):
143+
(self.darray[:, 0, 0] + 1j).plot()
144+
145+
def test_1d_bool(self):
146+
xr.ones_like(self.darray[:, 0, 0], dtype=np.bool).plot()
147+
142148
def test_1d_x_y_kw(self):
143149
z = np.arange(10)
144150
da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f")
@@ -989,6 +995,13 @@ def test_1d_raises_valueerror(self):
989995
with raises_regex(ValueError, r"DataArray must be 2d"):
990996
self.plotfunc(self.darray[0, :])
991997

998+
def test_bool(self):
999+
xr.ones_like(self.darray, dtype=np.bool).plot()
1000+
1001+
def test_complex_raises_typeerror(self):
1002+
with raises_regex(TypeError, "complex128"):
1003+
(self.darray + 1j).plot()
1004+
9921005
def test_3d_raises_valueerror(self):
9931006
a = DataArray(easy_array((2, 3, 4)))
9941007
if self.plotfunc.__name__ == "imshow":

0 commit comments

Comments
 (0)