80
80
from matplotlib .artist import Artist
81
81
from matplotlib .axes import Axes
82
82
from matplotlib .axis import Axis
83
+ from matplotlib .figure import Figure
83
84
84
85
from pandas ._typing import (
85
86
IndexLabel ,
@@ -241,7 +242,8 @@ def __init__(
241
242
self .stacked = kwds .pop ("stacked" , False )
242
243
243
244
self .ax = ax
244
- self .fig = fig
245
+ # TODO: deprecate fig keyword as it is ignored, not passed in tests
246
+ # as of 2023-11-05
245
247
self .axes = np .array ([], dtype = object ) # "real" version get set in `generate`
246
248
247
249
# parse errorbar input if given
@@ -449,11 +451,11 @@ def draw(self) -> None:
449
451
def generate (self ) -> None :
450
452
self ._args_adjust ()
451
453
self ._compute_plot_data ()
452
- self ._setup_subplots ()
453
- self ._make_plot ()
454
+ fig = self ._setup_subplots ()
455
+ self ._make_plot (fig )
454
456
self ._add_table ()
455
457
self ._make_legend ()
456
- self ._adorn_subplots ()
458
+ self ._adorn_subplots (fig )
457
459
458
460
for ax in self .axes :
459
461
self ._post_plot_logic_common (ax , self .data )
@@ -495,7 +497,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
495
497
new_ax .set_yscale ("symlog" )
496
498
return new_ax
497
499
498
- def _setup_subplots (self ):
500
+ def _setup_subplots (self ) -> Figure :
499
501
if self .subplots :
500
502
naxes = (
501
503
self .nseries if isinstance (self .subplots , bool ) else len (self .subplots )
@@ -538,8 +540,8 @@ def _setup_subplots(self):
538
540
elif self .logy == "sym" or self .loglog == "sym" :
539
541
[a .set_yscale ("symlog" ) for a in axes ]
540
542
541
- self .fig = fig
542
543
self .axes = axes
544
+ return fig
543
545
544
546
@property
545
547
def result (self ):
@@ -637,7 +639,7 @@ def _compute_plot_data(self):
637
639
638
640
self .data = numeric_data .apply (self ._convert_to_ndarray )
639
641
640
- def _make_plot (self ):
642
+ def _make_plot (self , fig : Figure ):
641
643
raise AbstractMethodError (self )
642
644
643
645
def _add_table (self ) -> None :
@@ -672,11 +674,11 @@ def _post_plot_logic_common(self, ax, data):
672
674
def _post_plot_logic (self , ax , data ) -> None :
673
675
"""Post process for each axes. Overridden in child classes"""
674
676
675
- def _adorn_subplots (self ):
677
+ def _adorn_subplots (self , fig : Figure ):
676
678
"""Common post process unrelated to data"""
677
679
if len (self .axes ) > 0 :
678
- all_axes = self ._get_subplots ()
679
- nrows , ncols = self ._get_axes_layout ()
680
+ all_axes = self ._get_subplots (fig )
681
+ nrows , ncols = self ._get_axes_layout (fig )
680
682
handle_shared_axes (
681
683
axarr = all_axes ,
682
684
nplots = len (all_axes ),
@@ -723,7 +725,7 @@ def _adorn_subplots(self):
723
725
for ax , title in zip (self .axes , self .title ):
724
726
ax .set_title (title )
725
727
else :
726
- self . fig .suptitle (self .title )
728
+ fig .suptitle (self .title )
727
729
else :
728
730
if is_list_like (self .title ):
729
731
msg = (
@@ -1114,17 +1116,17 @@ def _get_errorbars(
1114
1116
errors [kw ] = err
1115
1117
return errors
1116
1118
1117
- def _get_subplots (self ):
1119
+ def _get_subplots (self , fig : Figure ):
1118
1120
from matplotlib .axes import Subplot
1119
1121
1120
1122
return [
1121
1123
ax
1122
- for ax in self . fig .get_axes ()
1124
+ for ax in fig .get_axes ()
1123
1125
if (isinstance (ax , Subplot ) and ax .get_subplotspec () is not None )
1124
1126
]
1125
1127
1126
- def _get_axes_layout (self ) -> tuple [int , int ]:
1127
- axes = self ._get_subplots ()
1128
+ def _get_axes_layout (self , fig : Figure ) -> tuple [int , int ]:
1129
+ axes = self ._get_subplots (fig )
1128
1130
x_set = set ()
1129
1131
y_set = set ()
1130
1132
for ax in axes :
@@ -1172,7 +1174,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
1172
1174
ax .set_xlabel (xlabel )
1173
1175
ax .set_ylabel (ylabel )
1174
1176
1175
- def _plot_colorbar (self , ax : Axes , ** kwds ):
1177
+ def _plot_colorbar (self , ax : Axes , * , fig : Figure , * *kwds ):
1176
1178
# Addresses issues #10611 and #10678:
1177
1179
# When plotting scatterplots and hexbinplots in IPython
1178
1180
# inline backend the colorbar axis height tends not to
@@ -1189,7 +1191,7 @@ def _plot_colorbar(self, ax: Axes, **kwds):
1189
1191
# use the last one which contains the latest information
1190
1192
# about the ax
1191
1193
img = ax .collections [- 1 ]
1192
- return self . fig .colorbar (img , ax = ax , ** kwds )
1194
+ return fig .colorbar (img , ax = ax , ** kwds )
1193
1195
1194
1196
1195
1197
class ScatterPlot (PlanePlot ):
@@ -1209,7 +1211,7 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
1209
1211
c = self .data .columns [c ]
1210
1212
self .c = c
1211
1213
1212
- def _make_plot (self ):
1214
+ def _make_plot (self , fig : Figure ):
1213
1215
x , y , c , data = self .x , self .y , self .c , self .data
1214
1216
ax = self .axes [0 ]
1215
1217
@@ -1274,7 +1276,7 @@ def _make_plot(self):
1274
1276
)
1275
1277
if cb :
1276
1278
cbar_label = c if c_is_column else ""
1277
- cbar = self ._plot_colorbar (ax , label = cbar_label )
1279
+ cbar = self ._plot_colorbar (ax , fig = fig , label = cbar_label )
1278
1280
if color_by_categorical :
1279
1281
cbar .set_ticks (np .linspace (0.5 , n_cats - 0.5 , n_cats ))
1280
1282
cbar .ax .set_yticklabels (self .data [c ].cat .categories )
@@ -1306,7 +1308,7 @@ def __init__(self, data, x, y, C=None, **kwargs) -> None:
1306
1308
C = self .data .columns [C ]
1307
1309
self .C = C
1308
1310
1309
- def _make_plot (self ) -> None :
1311
+ def _make_plot (self , fig : Figure ) -> None :
1310
1312
x , y , data , C = self .x , self .y , self .data , self .C
1311
1313
ax = self .axes [0 ]
1312
1314
# pandas uses colormap, matplotlib uses cmap.
@@ -1321,7 +1323,7 @@ def _make_plot(self) -> None:
1321
1323
1322
1324
ax .hexbin (data [x ].values , data [y ].values , C = c_values , cmap = cmap , ** self .kwds )
1323
1325
if cb :
1324
- self ._plot_colorbar (ax )
1326
+ self ._plot_colorbar (ax , fig = fig )
1325
1327
1326
1328
def _make_legend (self ) -> None :
1327
1329
pass
@@ -1358,7 +1360,7 @@ def _is_ts_plot(self) -> bool:
1358
1360
def _use_dynamic_x (self ):
1359
1361
return use_dynamic_x (self ._get_ax (0 ), self .data )
1360
1362
1361
- def _make_plot (self ) -> None :
1363
+ def _make_plot (self , fig : Figure ) -> None :
1362
1364
if self ._is_ts_plot ():
1363
1365
data = maybe_convert_index (self ._get_ax (0 ), self .data )
1364
1366
@@ -1680,7 +1682,7 @@ def _plot( # type: ignore[override]
1680
1682
def _start_base (self ):
1681
1683
return self .bottom
1682
1684
1683
- def _make_plot (self ) -> None :
1685
+ def _make_plot (self , fig : Figure ) -> None :
1684
1686
colors = self ._get_colors ()
1685
1687
ncolors = len (colors )
1686
1688
@@ -1842,7 +1844,7 @@ def _args_adjust(self) -> None:
1842
1844
def _validate_color_args (self ) -> None :
1843
1845
pass
1844
1846
1845
- def _make_plot (self ) -> None :
1847
+ def _make_plot (self , fig : Figure ) -> None :
1846
1848
colors = self ._get_colors (num_colors = len (self .data ), color_kwds = "colors" )
1847
1849
self .kwds .setdefault ("colors" , colors )
1848
1850
0 commit comments