Skip to content

Commit 434f9e8

Browse files
authored
Fix step plots with hue (#6944)
1 parent 15c6182 commit 434f9e8

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ Bug fixes
5353
By `Michael Niklas <https://github.com/headtr1ck>`_.
5454
- Harmonize returned multi-indexed indexes when applying ``concat`` along new dimension (:issue:`6881`, :pull:`6889`)
5555
By `Fabian Hofmann <https://github.com/FabianHofmann>`_.
56+
- Fix step plots with ``hue`` arg. (:pull:`6944`)
57+
By `András Gunyhó <https://github.com/mgunyho>`_.
5658

5759
Documentation
5860
~~~~~~~~~~~~~

xarray/plot/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,13 +564,16 @@ def _resolve_intervals_1dplot(
564564
if kwargs.get("drawstyle", "").startswith("steps-"):
565565

566566
remove_drawstyle = False
567+
567568
# Convert intervals to double points
568-
if _valid_other_type(np.array([xval, yval]), [pd.Interval]):
569+
x_is_interval = _valid_other_type(xval, [pd.Interval])
570+
y_is_interval = _valid_other_type(yval, [pd.Interval])
571+
if x_is_interval and y_is_interval:
569572
raise TypeError("Can't step plot intervals against intervals.")
570-
if _valid_other_type(xval, [pd.Interval]):
573+
elif x_is_interval:
571574
xval, yval = _interval_to_double_bound_points(xval, yval)
572575
remove_drawstyle = True
573-
if _valid_other_type(yval, [pd.Interval]):
576+
elif y_is_interval:
574577
yval, xval = _interval_to_double_bound_points(yval, xval)
575578
remove_drawstyle = True
576579

xarray/tests/test_plot.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -796,6 +796,24 @@ def test_step_with_where(self, where):
796796
hdl = self.darray[0, 0].plot.step(where=where)
797797
assert hdl[0].get_drawstyle() == f"steps-{where}"
798798

799+
def test_step_with_hue(self):
800+
hdl = self.darray[0].plot.step(hue="dim_2")
801+
assert hdl[0].get_drawstyle() == "steps-pre"
802+
803+
@pytest.mark.parametrize("where", ["pre", "post", "mid"])
804+
def test_step_with_hue_and_where(self, where):
805+
hdl = self.darray[0].plot.step(hue="dim_2", where=where)
806+
assert hdl[0].get_drawstyle() == f"steps-{where}"
807+
808+
def test_drawstyle_steps(self):
809+
hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps")
810+
assert hdl[0].get_drawstyle() == "steps"
811+
812+
@pytest.mark.parametrize("where", ["pre", "post", "mid"])
813+
def test_drawstyle_steps_with_where(self, where):
814+
hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}")
815+
assert hdl[0].get_drawstyle() == f"steps-{where}"
816+
799817
def test_coord_with_interval_step(self):
800818
"""Test step plot with intervals."""
801819
bins = [-1, 0, 1, 2]
@@ -814,6 +832,15 @@ def test_coord_with_interval_step_y(self):
814832
self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins")
815833
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2)
816834

835+
def test_coord_with_interval_step_x_and_y_raises_valueeerror(self):
836+
"""Test that step plot with intervals both on x and y axes raises an error."""
837+
arr = xr.DataArray(
838+
[pd.Interval(0, 1), pd.Interval(1, 2)],
839+
coords=[("x", [pd.Interval(0, 1), pd.Interval(1, 2)])],
840+
)
841+
with pytest.raises(TypeError, match="intervals against intervals"):
842+
arr.plot.step()
843+
817844

818845
class TestPlotHistogram(PlotTestCase):
819846
@pytest.fixture(autouse=True)

0 commit comments

Comments
 (0)