88import matplotlib .pyplot as plt
99from matplotlib import ticker
1010from matplotlib .colors import Colormap
11- from scipy import stats
1211
1312
1413NONETYPE = type (None )
@@ -20,22 +19,22 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
2019 x_order , y_order , split_order , row_order , col_order ,
2120 orientation , sort_values , wrap , figsize , title , sharex , sharey ,
2221 xlabel , ylabel , xlim , ylim , xscale , yscale , cmap ,
23- x_textwrap , y_textwrap , check_numeric = False ):
22+ x_textwrap , y_textwrap , check_numeric = False , kind = None ):
2423
2524 self .used_columns = set ()
2625 self .data = self .get_data (data )
2726 self .x = self .get_col (x )
2827 self .y = self .get_col (y )
2928 self .validate_x_y ()
3029 self .orientation = orientation
31- self .aggfunc = aggfunc
30+ self .aggfunc = self . get_aggfunc ( aggfunc )
3231 self .groupby = self .get_groupby ()
3332 self .split = self .get_col (split )
3433 self .row = self .get_col (row )
3534 self .col = self .get_col (col )
3635
3736 self .agg = self .set_agg ()
38- self .make_groups_categorical ()
37+ self .make_groups_categorical (kind )
3938 self .validate_numeric (check_numeric )
4039
4140 self .x_order = self .validate_order (x_order , 'x' )
@@ -79,10 +78,14 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
7978 self .final_data = self .get_final_data ()
8079 self .style_fig ()
8180 self .add_ax_titles ()
81+ self .add_fig_title ()
8282
def get_data(self, data):
    """Validate ``data`` and return it as an independent DataFrame.

    Parameters
    ----------
    data : pandas.DataFrame or pandas.Series
        A Series is promoted to a one-column DataFrame before validation.

    Returns
    -------
    pandas.DataFrame
        A copy of the input, so in-place edits made to the returned
        frame do not reach the caller's object.

    Raises
    ------
    TypeError
        If ``data`` is neither a DataFrame nor a Series.
    ValueError
        If ``data`` has no rows.
    """
    if isinstance(data, pd.Series):
        # Promote first and fall through to the shared checks, so an empty
        # Series is rejected exactly like an empty DataFrame.
        data = data.to_frame()
    elif not isinstance(data, pd.DataFrame):
        raise TypeError('`data` must be a pandas DataFrame or Series')

    if len(data) == 0:
        raise ValueError('DataFrame contains no data')
    return data.copy()
@@ -104,6 +107,13 @@ def validate_x_y(self):
104107 if self .x == self .y and self .x is not None and self .y is not None :
105108 raise ValueError ('`x` and `y` cannot be the same column name' )
106109
def get_aggfunc(self, aggfunc):
    """Translate the special aggregation names into callables.

    Parameters
    ----------
    aggfunc : str, callable or None
        ``'countna'`` and ``'percna'`` are replaced by functions that
        call ``isna()`` on their argument and return the sum or the mean
        of the result, respectively. Anything else is passed through.

    Returns
    -------
    callable or the original ``aggfunc``
    """
    # Only compare strings: an array-like compared with ``==`` yields an
    # element-wise result whose truth value is ambiguous.
    if isinstance(aggfunc, str):
        if aggfunc == 'countna':
            return lambda x: x.isna().sum()
        if aggfunc == 'percna':
            return lambda x: x.isna().mean()
    return aggfunc
116+
107117 def get_groupby (self ):
108118 if self .x is None or self .y is None or self .aggfunc is None :
109119 return
@@ -142,12 +152,16 @@ def filter_data(self):
142152 if name and self .data [name ].dtype .name == 'category' :
143153 self .data [name ].cat .remove_unused_categories (inplace = True )
144154
def make_groups_categorical(self, kind):
    """Convert the grouping columns of ``self.data`` to categorical dtype.

    The columns named by ``self.groupby``, ``self.split``, ``self.row``
    and ``self.col`` are converted in place when they are not already
    categorical. When ``kind == 'count'`` the ``x`` column (or ``y`` if
    ``x`` is not set) is converted as well.

    Parameters
    ----------
    kind : str or None
        Plot kind; only the value ``'count'`` changes behaviour here.
    """
    category_cols = [self.groupby, self.split, self.row, self.col]
    if kind == 'count':
        # Pick by ``None``-ness rather than truthiness so a falsy but valid
        # column label (e.g. ``0``) is not mistaken for "not provided".
        category_cols.append(self.x if self.x is not None else self.y)

    for col in category_cols:
        if col is None:
            continue
        if self.data[col].dtype.name != 'category':
            self.data[col] = self.data[col].astype('category')
151165
152166 def validate_numeric (self , check_numeric ):
153167 if check_numeric :
@@ -348,6 +362,7 @@ def get_fig_shape(self):
348362 return nrows , ncols
349363
350364 def get_data_for_every_plot (self ):
365+ # TODO: catch KeyError for groups that don't exist
351366 rows , cols = self .get_row_col_order ()
352367 if self .plot_type == 'row_only' :
353368 return [(row , self .data .loc [row ]) for row in rows ]
@@ -362,7 +377,7 @@ def get_data_for_every_plot(self):
362377 with warnings .catch_warnings ():
363378 warnings .simplefilter ("ignore" )
364379 data = self .data .loc [group ]
365- except KeyError :
380+ except ( KeyError , TypeError ) :
366381 data = self .data .iloc [:0 ]
367382 groups .append ((group , data ))
368383 return groups
@@ -423,7 +438,7 @@ def get_order(self, arr, vals):
423438
def reverse_order(self, order):
    """Return whether the requested ordering must be reversed.

    Reversal is needed for ``'desc'`` when ``self.orientation`` is
    ``'v'``, and for ``'asc'`` or ``None`` when it is ``'h'``.

    Parameters
    ----------
    order : str, None or other
        Only the strings ``'asc'``/``'desc'`` and ``None`` can trigger a
        reversal; any other value yields ``False``.

    Returns
    -------
    bool
    """
    # Guard with isinstance so array-like ``order`` values are never
    # compared element-wise against a string.
    is_text = isinstance(order, str)
    if self.orientation == 'v':
        return is_text and order == 'desc'
    if self.orientation == 'h':
        return order is None or (is_text and order == 'asc')
    return False
428443
429444 def order_xy (self , x , y ):
@@ -471,7 +486,8 @@ def get_ordered_groups(self, data, specific_order, kind):
471486 order = []
472487 groups = []
473488 sort = specific_order is not None
474- for grp , data_grp in data .groupby (getattr (self , kind ), sort = sort ):
489+ # TODO: Need to decide defaults for x_order, y_order etc... either None or 'asc'
490+ for grp , data_grp in data .groupby (getattr (self , kind ), sort = True ):
475491 order .append ((grp , data_grp ))
476492 groups .append (grp )
477493
@@ -535,12 +551,13 @@ def get_final_groups(self, data, split_label, row_label, col_label):
535551 s = data [col ]
536552 x , y = s .index .values , s .values
537553 x , y = self .get_correct_data_order (x , y )
538- groups .append ((x , y , split_label , col , row_label , col_label ))
554+ x , y = (x , y ) if self .orientation == 'v' else (y , x )
555+ groups .append ((x , y , col , None , row_label , col_label ))
539556 else :
540557 # simple raw plot - make sure to warn when lots of data for bar/box/hist
541558 # one graph per row - OK for scatterplots and line plots
542559 x , y = self .get_correct_data_order (data [self .x ], data [self .y ])
543- groups .append ((x , y , None , None , row_label , col_label ))
560+ groups .append ((x , y , split_label , None , row_label , col_label ))
544561 return groups
545562
546563 def get_final_data (self ):
@@ -635,8 +652,8 @@ def add_ticklabels(self, labels, ax, delta=0):
635652 ax .set_yticks (ticks - delta )
636653 ax .set_yticklabels (labels )
637654
638- def add_legend (self , handles = None , labels = None ):
639- if self . split :
655+ def add_legend (self , label = None , handles = None , labels = None ):
656+ if label is not None :
640657 if handles is None :
641658 handles , labels = self .axs [0 ].get_legend_handles_labels ()
642659 ncol = len (labels ) // 8 + 1
@@ -664,4 +681,7 @@ def update_fig_size(self, n_splits, n_groups_per_split):
664681 height = new_size * .8 * self .fig_shape [0 ]
665682 width = width * self .fig_shape [1 ]
666683 width , height = min (width , 25 ), min (height , 25 )
667- self .fig .set_size_inches (width , height )
684+ self .fig .set_size_inches (width , height )
685+
def add_fig_title(self):
    """Set ``self.title`` as the figure-level title.

    The title is passed to ``self.fig.suptitle`` with ``y=1.02``.
    """
    self.fig.suptitle(self.title, y=1.02)
0 commit comments