@@ -49,6 +49,7 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
4949 self .sort_values = sort_values
5050 self .groupby_sort = True
5151 self .wrap = wrap
52+ self .figsize = figsize
5253 self .title = title
5354 self .sharex = sharex
5455 self .sharey = sharey
@@ -64,24 +65,14 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
6465 self .x_rot = x_rot
6566 self .y_rot = y_rot
6667
67- self .validate_args (figsize )
68+ self .validate_args ()
6869 self .plot_type = self .get_plot_type ()
6970 self .agg_kind = self .get_agg_kind ()
7071 self .data = self .set_index ()
7172 self .rows , self .cols = self .get_uniques ()
7273 self .rows , self .cols = self .get_row_col_order ()
7374 self .fig_shape = self .get_fig_shape ()
74- self .user_figsize = figsize is not None
75- self .figsize = self .get_figsize (figsize )
76- self .original_rcParams = plt .rcParams .copy ()
77- self .set_rcParams ()
78- self .fig , self .axs = self .create_figure ()
79- self .set_color_cycle ()
80- self .data_for_plots = self .get_data_for_every_plot ()
81- self .final_data = self .get_final_data ()
82- self .style_fig ()
83- self .add_ax_titles ()
84- self .add_fig_title ()
75+
8576
8677 def get_data (self , data ):
8778 if isinstance (data , pd .Series ):
@@ -225,22 +216,11 @@ def get_colors(self, cmap):
225216 raise TypeError ('`cmap` must be a string name of a colormap, a matplotlib colormap '
226217 'instance, list, or tuple of colors' )
227218
228- def validate_args (self , figsize ):
229- self .validate_figsize (figsize )
219+ def validate_args (self ):
230220 self .validate_plot_args ()
231221 self .validate_mpl_args ()
232222 self .validate_sort_values ()
233223
234- def validate_figsize (self , figsize ):
235- if isinstance (figsize , (list , tuple )):
236- if len (figsize ) != 2 :
237- raise ValueError ('figsize must be a two-item tuple/list' )
238- for val in figsize :
239- if not isinstance (val , (int , float )):
240- raise ValueError ('Each item in figsize must be an integer or a float' )
241- elif figsize is not None :
242- raise TypeError ('figsize must be a two-item tuple' )
243-
244224 def validate_plot_args (self ):
245225 if self .orientation not in ('v' , 'h' ):
246226 raise ValueError ('`orientation` must be either "v" or "h".' )
@@ -397,25 +377,6 @@ def get_labels(self, labels):
397377 return None , str (labels )
398378 return None , None
399379
400- def get_figsize (self , figsize ):
401- if figsize :
402- return figsize
403- else :
404- return self .fig_shape [1 ] * 4 , self .fig_shape [0 ] * 3
405-
406- def create_figure (self ):
407- fig = plt .Figure (tight_layout = True , dpi = 144 , figsize = self .figsize )
408- axs = fig .subplots (* self .fig_shape , sharex = self .sharex , sharey = self .sharey )
409- if self .fig_shape != (1 , 1 ):
410- axs = axs .flatten (order = 'F' )
411- else :
412- axs = [axs ]
413- return fig , axs
414-
415- def set_color_cycle (self ):
416- for ax in self .axs :
417- ax .set_prop_cycle (color = self .colors )
418-
419380 def sort_values_xy (self , x , y ):
420381 grp , num = (x , y ) if self .orientation == 'v' else (y , x )
421382 if self .sort_values is None :
@@ -522,7 +483,7 @@ def get_final_groups(self, data, split_label, row_label, col_label):
522483 else :
523484 col = self .x or self .y
524485 vals = data [col ]
525- groups .append ((vals , split_label , None , row_label , col_label ))
486+ groups .append ((vals , split_label , self . col , row_label , col_label ))
526487 elif self .groupby is not None :
527488 try :
528489 s = data .groupby (self .groupby , sort = self .groupby_sort )[self .agg ].agg (self .aggfunc )
@@ -536,18 +497,18 @@ def get_final_groups(self, data, split_label, row_label, col_label):
536497 x , y = s .index .values , s .values
537498 x , y = (x , y ) if self .orientation == 'v' else (y , x )
538499 x , y = self .get_correct_data_order (x , y )
539- groups .append ((x , y , split_label , None , row_label , col_label ))
500+ groups .append ((x , y , split_label , self . groupby , row_label , col_label ))
540501 elif self .x is None or self .y is None :
541502 if self .x :
542503 s = data [self .x ]
543504 x , y = s .values , s .index .values
544505 x , y = self .get_correct_data_order (x , y )
545- groups .append ((x , y , split_label , None , row_label , col_label ))
506+ groups .append ((x , y , split_label , self . x , row_label , col_label ))
546507 elif self .y :
547508 s = data [self .y ]
548509 x , y = s .index .values , s .values
549510 x , y = self .get_correct_data_order (x , y )
550- groups .append ((x , y , split_label , None , row_label , col_label ))
511+ groups .append ((x , y , split_label , self . y , row_label , col_label ))
551512 else :
552513 # wide data
553514 for col in self .get_wide_columns (data ):
@@ -563,6 +524,76 @@ def get_final_groups(self, data, split_label, row_label, col_label):
563524 groups .append ((x , y , split_label , None , row_label , col_label ))
564525 return groups
565526
527+ def get_x_y_plot (self , x , y ):
528+ x_plot , y_plot = x , y
529+ if x_plot .dtype .kind == 'O' :
530+ x_plot = np .arange (len (x_plot ))
531+ if y_plot .dtype .kind == 'O' :
532+ y_plot = np .arange (len (y_plot ))
533+ return x_plot , y_plot
534+
535+ def get_distribution_data (self , info ):
536+ cur_data = defaultdict (list )
537+ cur_ticklabels = defaultdict (list )
538+ for vals , split_label , col_name , row_label , col_label in info :
539+ cur_data [split_label ].append (vals )
540+ cur_ticklabels [split_label ].append (col_name )
541+ return cur_data , cur_ticklabels
542+
543+
544+ class MPLCommon (CommonPlot ):
545+
546+ def __init__ (self , x , y , data , aggfunc , split , row , col ,
547+ x_order , y_order , split_order , row_order , col_order ,
548+ orientation , sort_values , wrap , figsize , title , sharex , sharey ,
549+ xlabel , ylabel , xlim , ylim , xscale , yscale , cmap ,
550+ x_textwrap , y_textwrap , x_rot , y_rot ,
551+ check_numeric = False , kind = None ):
552+ super ().__init__ (x , y , data , aggfunc , split , row , col ,
553+ x_order , y_order , split_order , row_order , col_order ,
554+ orientation , sort_values , wrap , figsize , title , sharex , sharey ,
555+ xlabel , ylabel , xlim , ylim , xscale , yscale , cmap ,
556+ x_textwrap , y_textwrap , x_rot , y_rot ,
557+ check_numeric = False , kind = None )
558+ self .figsize = self .get_figsize ()
559+ self .user_figsize = self .figsize is not None
560+ self .original_rcParams = plt .rcParams .copy ()
561+ self .set_rcParams ()
562+ self .fig , self .axs = self .create_figure ()
563+ self .set_color_cycle ()
564+ self .data_for_plots = self .get_data_for_every_plot ()
565+ self .final_data = self .get_final_data ()
566+ self .style_fig ()
567+ self .add_ax_titles ()
568+ self .add_fig_title ()
569+
570+ def get_figsize (self ):
571+ if self .figsize is None :
572+ return
573+ elif isinstance (self .figsize , (list , tuple )):
574+ if len (self .figsize ) != 2 :
575+ raise ValueError ('figsize must be a two-item tuple/list' )
576+ for val in self .figsize :
577+ if not isinstance (val , (int , float )):
578+ raise ValueError ('Each item in figsize must be an integer or a float' )
579+ else :
580+ raise TypeError ('figsize must be a two-item tuple' )
581+
582+ return self .fig_shape [1 ] * 4 , self .fig_shape [0 ] * 3
583+
584+ def create_figure (self ):
585+ fig = plt .Figure (tight_layout = True , dpi = 144 , figsize = self .figsize )
586+ axs = fig .subplots (* self .fig_shape , sharex = self .sharex , sharey = self .sharey )
587+ if self .fig_shape != (1 , 1 ):
588+ axs = axs .flatten (order = 'F' )
589+ else :
590+ axs = [axs ]
591+ return fig , axs
592+
593+ def set_color_cycle (self ):
594+ for ax in self .axs :
595+ ax .set_prop_cycle (color = self .colors )
596+
566597 def get_final_data (self ):
567598 # create list of data for each call to plotting method
568599 final_data = defaultdict (list )
@@ -627,22 +658,6 @@ def set_rcParams(self):
627658 plt .rcParams ['font.size' ] = 6
628659 plt .rcParams ['font.family' ] = 'Helvetica'
629660
630- def get_x_y_plot (self , x , y ):
631- x_plot , y_plot = x , y
632- if x_plot .dtype .kind == 'O' :
633- x_plot = np .arange (len (x_plot ))
634- if y_plot .dtype .kind == 'O' :
635- y_plot = np .arange (len (y_plot ))
636- return x_plot , y_plot
637-
638- def get_distribution_data (self , info ):
639- cur_data = defaultdict (list )
640- cur_ticklabels = defaultdict (list )
641- for vals , split_label , col_name , row_label , col_label in info :
642- cur_data [split_label ].append (vals )
643- cur_ticklabels [split_label ].append (col_name )
644- return cur_data , cur_ticklabels
645-
646661 def add_ticklabels (self , labels , ax , delta = 0 ):
647662 ticks = np .arange (len (labels ))
648663 ha , va = 'center' , 'center'
@@ -700,3 +715,177 @@ def update_fig_size(self, n_splits, n_groups_per_split):
700715
701716 def add_fig_title (self ):
702717 self .fig .suptitle (self .title , y = 1.02 )
718+
719+
720+ import plotly .graph_objects as go
721+ from plotly .subplots import make_subplots
722+
723+
724+ class PlotlyCommon (CommonPlot ):
725+
726+ def __init__ (self , x , y , data , aggfunc , split , row , col ,
727+ x_order , y_order , split_order , row_order , col_order ,
728+ orientation , sort_values , wrap , figsize , title , sharex , sharey ,
729+ xlabel , ylabel , xlim , ylim , xscale , yscale , cmap ,
730+ x_textwrap , y_textwrap , x_rot , y_rot ,
731+ check_numeric = False , kind = None ):
732+ super ().__init__ (x , y , data , aggfunc , split , row , col ,
733+ x_order , y_order , split_order , row_order , col_order ,
734+ orientation , sort_values , wrap , figsize , title , sharex , sharey ,
735+ xlabel , ylabel , xlim , ylim , xscale , yscale , cmap ,
736+ x_textwrap , y_textwrap , x_rot , y_rot ,
737+ check_numeric = False , kind = None )
738+
739+ self .data_for_plots = self .get_data_for_every_plot ()
740+ self .final_data = self .get_final_data ()
741+ self .fig = self .create_figure ()
742+
743+ def create_figure (self ):
744+ titles = self .get_subplot_titles ()
745+ fig = make_subplots (rows = self .fig_shape [0 ], cols = self .fig_shape [1 ], subplot_titles = titles ,
746+ shared_xaxes = self .sharex , shared_yaxes = self .sharey ,
747+ horizontal_spacing = .03 )
748+ fig .update_layout (title_text = self .title , legend_title_text = self .split )
749+ return fig
750+
751+ def get_final_data (self ):
752+ # create list of data for each call to plotting method
753+ final_data = defaultdict (list )
754+ locs = []
755+ for i in range (self .fig_shape [0 ]):
756+ for j in range (self .fig_shape [1 ]):
757+ locs .append ((i + 1 , j + 1 ))
758+
759+ for (labels , data ), loc in zip (self .data_for_plots , locs ):
760+ row_label , col_label = self .get_labels (labels )
761+ if self .split :
762+ for grp , data_grp in self .get_ordered_groups (data , self .split_order , 'split' ):
763+ final_data [loc ].extend (self .get_final_groups (data_grp , grp , row_label , col_label ))
764+ else :
765+ final_data [loc ].extend (self .get_final_groups (data , None , row_label , col_label ))
766+ return final_data
767+
768+ def get_subplot_titles (self ):
769+ titles = []
770+ for (i , j ), info in self .final_data .items ():
771+ row_label , col_label = info [0 ][- 2 :]
772+ if row_label is not None :
773+ row_label = str (row_label )
774+ if col_label is not None :
775+ col_label = str (col_label )
776+ row_label = row_label or ''
777+ col_label = col_label or ''
778+ if row_label and col_label :
779+ title = row_label + ' - ' + col_label
780+ else :
781+ title = row_label or col_label
782+ title = textwrap .fill (str (title ), 30 )
783+ titles .append (title )
784+ return titles
785+
786+
787+ class CountCommon (CommonPlot ):
788+
789+ def get_count_dict (self , normalize ):
790+ count_dict = {}
791+
792+ if isinstance (normalize , str ):
793+ if normalize in (val , self .split , self .row , self .col ):
794+ normalize = [normalize ]
795+
796+ if isinstance (normalize , tuple ):
797+ normalize = list (normalize )
798+ elif hasattr (normalize , 'tolist' ):
799+ normalize = normalize .tolist ()
800+ elif not isinstance (normalize , (bool , list )):
801+ raise ValueError ('`normalize` must either be `True`/`False`, one of the columns passed '
802+ 'to `val`, `split`, `row` or `col`, or a list of '
803+ 'those columns' )
804+ normalize_kind = None
805+ if isinstance (normalize , list ):
806+ row_col = []
807+ val_split = []
808+ for col in normalize :
809+ if col in (self .row , self .col ):
810+ row_col .append (col )
811+ elif col in (val , self .split ):
812+ val_split .append (col )
813+ else :
814+ raise ValueError ('Columns passed to `normalize` must be the same as '
815+ ' `val`, `split`, `row` or `col`.' )
816+
817+ if row_col :
818+ all_counts = {}
819+ for grp , data in self .data .groupby (row_col ):
820+ if len (row_col ) == 1 :
821+ grp = str (grp )
822+ else :
823+ grp = tuple (str (g ) for g in grp )
824+
825+ if val_split :
826+ normalize_kind = 'all'
827+ all_counts [grp ] = data .groupby (val_split ).size ()
828+ else :
829+ normalize_kind = 'grid'
830+ all_counts [grp ] = len (data )
831+ else :
832+ normalize_kind = 'single'
833+ all_counts = self .data .groupby (val_split ).size ()
834+
835+ n = 0
836+ for key , info in self .final_data .items ():
837+ columns = []
838+ vcs = []
839+ for vals , split_label , col_name , row_label , col_label in info :
840+ vcs .append (vals .value_counts ())
841+ columns .append (split_label )
842+
843+ df = pd .concat (vcs , axis = 1 )
844+ df .columns = columns
845+ df .index .name = vals .name
846+ if normalize_kind == 'single' :
847+ if len (val_split ) == 2 :
848+ df = df / all_counts .unstack (self .split )
849+ elif df .index .name == all_counts .index .name :
850+ df = df .div (all_counts , axis = 0 )
851+ else :
852+ df = df / all_counts
853+ elif normalize_kind in ('grid' , 'all' ):
854+ grp = []
855+ for col in normalize :
856+ if col == self .row :
857+ grp .append (row_label )
858+ if col == self .col :
859+ grp .append (col_label )
860+
861+ if len (grp ) == 1 :
862+ grp = grp [0 ]
863+ else :
864+ grp = tuple (grp )
865+ grp_val = all_counts [grp ]
866+
867+ if normalize_kind == 'grid' :
868+ df = df / grp_val
869+ elif len (val_split ) == 2 :
870+ df = df / grp_val .unstack (self .split )
871+ elif df .index .name == grp_val .index .name :
872+ df = df .div (grp_val , axis = 0 )
873+ else :
874+ df = df / grp_val
875+
876+ else :
877+ n += df .sum ().sum ()
878+ count_dict [key ] = df
879+
880+ if normalize is True :
881+ count_dict = {key : df / n for key , df in count_dict .items ()}
882+
883+ return count_dict
884+
885+
886+ class MPLCount (CountCommon , MPLCommon ):
887+ pass
888+
889+
890+ class PlotlyCount (CountCommon , PlotlyCommon ):
891+ pass
0 commit comments