Skip to content

Commit 56a4d57

Browse files
authored
REF: make plotting less stateful (#55837)
1 parent 5d82d8b commit 56a4d57

File tree

3 files changed

+30
-26
lines changed

3 files changed

+30
-26
lines changed

pandas/plotting/_matplotlib/boxplot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from collections.abc import Collection
3636

3737
from matplotlib.axes import Axes
38+
from matplotlib.figure import Figure
3839
from matplotlib.lines import Line2D
3940

4041
from pandas._typing import MatplotlibColor
@@ -177,7 +178,7 @@ def maybe_color_bp(self, bp) -> None:
177178
if not self.kwds.get("capprops"):
178179
setp(bp["caps"], color=caps, alpha=1)
179180

180-
def _make_plot(self) -> None:
181+
def _make_plot(self, fig: Figure) -> None:
181182
if self.subplots:
182183
self._return_obj = pd.Series(dtype=object)
183184

pandas/plotting/_matplotlib/core.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
from matplotlib.artist import Artist
8181
from matplotlib.axes import Axes
8282
from matplotlib.axis import Axis
83+
from matplotlib.figure import Figure
8384

8485
from pandas._typing import (
8586
IndexLabel,
@@ -241,7 +242,8 @@ def __init__(
241242
self.stacked = kwds.pop("stacked", False)
242243

243244
self.ax = ax
244-
self.fig = fig
245+
# TODO: deprecate fig keyword as it is ignored, not passed in tests
246+
# as of 2023-11-05
245247
self.axes = np.array([], dtype=object) # "real" version get set in `generate`
246248

247249
# parse errorbar input if given
@@ -449,11 +451,11 @@ def draw(self) -> None:
449451
def generate(self) -> None:
450452
self._args_adjust()
451453
self._compute_plot_data()
452-
self._setup_subplots()
453-
self._make_plot()
454+
fig = self._setup_subplots()
455+
self._make_plot(fig)
454456
self._add_table()
455457
self._make_legend()
456-
self._adorn_subplots()
458+
self._adorn_subplots(fig)
457459

458460
for ax in self.axes:
459461
self._post_plot_logic_common(ax, self.data)
@@ -495,7 +497,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
495497
new_ax.set_yscale("symlog")
496498
return new_ax
497499

498-
def _setup_subplots(self):
500+
def _setup_subplots(self) -> Figure:
499501
if self.subplots:
500502
naxes = (
501503
self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
@@ -538,8 +540,8 @@ def _setup_subplots(self):
538540
elif self.logy == "sym" or self.loglog == "sym":
539541
[a.set_yscale("symlog") for a in axes]
540542

541-
self.fig = fig
542543
self.axes = axes
544+
return fig
543545

544546
@property
545547
def result(self):
@@ -637,7 +639,7 @@ def _compute_plot_data(self):
637639

638640
self.data = numeric_data.apply(self._convert_to_ndarray)
639641

640-
def _make_plot(self):
642+
def _make_plot(self, fig: Figure):
641643
raise AbstractMethodError(self)
642644

643645
def _add_table(self) -> None:
@@ -672,11 +674,11 @@ def _post_plot_logic_common(self, ax, data):
672674
def _post_plot_logic(self, ax, data) -> None:
673675
"""Post process for each axes. Overridden in child classes"""
674676

675-
def _adorn_subplots(self):
677+
def _adorn_subplots(self, fig: Figure):
676678
"""Common post process unrelated to data"""
677679
if len(self.axes) > 0:
678-
all_axes = self._get_subplots()
679-
nrows, ncols = self._get_axes_layout()
680+
all_axes = self._get_subplots(fig)
681+
nrows, ncols = self._get_axes_layout(fig)
680682
handle_shared_axes(
681683
axarr=all_axes,
682684
nplots=len(all_axes),
@@ -723,7 +725,7 @@ def _adorn_subplots(self):
723725
for ax, title in zip(self.axes, self.title):
724726
ax.set_title(title)
725727
else:
726-
self.fig.suptitle(self.title)
728+
fig.suptitle(self.title)
727729
else:
728730
if is_list_like(self.title):
729731
msg = (
@@ -1114,17 +1116,17 @@ def _get_errorbars(
11141116
errors[kw] = err
11151117
return errors
11161118

1117-
def _get_subplots(self):
1119+
def _get_subplots(self, fig: Figure):
11181120
from matplotlib.axes import Subplot
11191121

11201122
return [
11211123
ax
1122-
for ax in self.fig.get_axes()
1124+
for ax in fig.get_axes()
11231125
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
11241126
]
11251127

1126-
def _get_axes_layout(self) -> tuple[int, int]:
1127-
axes = self._get_subplots()
1128+
def _get_axes_layout(self, fig: Figure) -> tuple[int, int]:
1129+
axes = self._get_subplots(fig)
11281130
x_set = set()
11291131
y_set = set()
11301132
for ax in axes:
@@ -1172,7 +1174,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
11721174
ax.set_xlabel(xlabel)
11731175
ax.set_ylabel(ylabel)
11741176

1175-
def _plot_colorbar(self, ax: Axes, **kwds):
1177+
def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds):
11761178
# Addresses issues #10611 and #10678:
11771179
# When plotting scatterplots and hexbinplots in IPython
11781180
# inline backend the colorbar axis height tends not to
@@ -1189,7 +1191,7 @@ def _plot_colorbar(self, ax: Axes, **kwds):
11891191
# use the last one which contains the latest information
11901192
# about the ax
11911193
img = ax.collections[-1]
1192-
return self.fig.colorbar(img, ax=ax, **kwds)
1194+
return fig.colorbar(img, ax=ax, **kwds)
11931195

11941196

11951197
class ScatterPlot(PlanePlot):
@@ -1209,7 +1211,7 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
12091211
c = self.data.columns[c]
12101212
self.c = c
12111213

1212-
def _make_plot(self):
1214+
def _make_plot(self, fig: Figure):
12131215
x, y, c, data = self.x, self.y, self.c, self.data
12141216
ax = self.axes[0]
12151217

@@ -1274,7 +1276,7 @@ def _make_plot(self):
12741276
)
12751277
if cb:
12761278
cbar_label = c if c_is_column else ""
1277-
cbar = self._plot_colorbar(ax, label=cbar_label)
1279+
cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label)
12781280
if color_by_categorical:
12791281
cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
12801282
cbar.ax.set_yticklabels(self.data[c].cat.categories)
@@ -1306,7 +1308,7 @@ def __init__(self, data, x, y, C=None, **kwargs) -> None:
13061308
C = self.data.columns[C]
13071309
self.C = C
13081310

1309-
def _make_plot(self) -> None:
1311+
def _make_plot(self, fig: Figure) -> None:
13101312
x, y, data, C = self.x, self.y, self.data, self.C
13111313
ax = self.axes[0]
13121314
# pandas uses colormap, matplotlib uses cmap.
@@ -1321,7 +1323,7 @@ def _make_plot(self) -> None:
13211323

13221324
ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds)
13231325
if cb:
1324-
self._plot_colorbar(ax)
1326+
self._plot_colorbar(ax, fig=fig)
13251327

13261328
def _make_legend(self) -> None:
13271329
pass
@@ -1358,7 +1360,7 @@ def _is_ts_plot(self) -> bool:
13581360
def _use_dynamic_x(self):
13591361
return use_dynamic_x(self._get_ax(0), self.data)
13601362

1361-
def _make_plot(self) -> None:
1363+
def _make_plot(self, fig: Figure) -> None:
13621364
if self._is_ts_plot():
13631365
data = maybe_convert_index(self._get_ax(0), self.data)
13641366

@@ -1680,7 +1682,7 @@ def _plot( # type: ignore[override]
16801682
def _start_base(self):
16811683
return self.bottom
16821684

1683-
def _make_plot(self) -> None:
1685+
def _make_plot(self, fig: Figure) -> None:
16841686
colors = self._get_colors()
16851687
ncolors = len(colors)
16861688

@@ -1842,7 +1844,7 @@ def _args_adjust(self) -> None:
18421844
def _validate_color_args(self) -> None:
18431845
pass
18441846

1845-
def _make_plot(self) -> None:
1847+
def _make_plot(self, fig: Figure) -> None:
18461848
colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
18471849
self.kwds.setdefault("colors", colors)
18481850

pandas/plotting/_matplotlib/hist.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
if TYPE_CHECKING:
4141
from matplotlib.axes import Axes
42+
from matplotlib.figure import Figure
4243

4344
from pandas._typing import PlottingOrientation
4445

@@ -113,7 +114,7 @@ def _plot( # type: ignore[override]
113114
cls._update_stacker(ax, stacking_id, n)
114115
return patches
115116

116-
def _make_plot(self) -> None:
117+
def _make_plot(self, fig: Figure) -> None:
117118
colors = self._get_colors()
118119
stacking_id = self._get_stacking_id()
119120

0 commit comments

Comments
 (0)