Skip to content

Commit 3e75f65

Browse files
committed
refactor
1 parent 9607aaf commit 3e75f65

File tree

2 files changed

+513
-393
lines changed

2 files changed

+513
-393
lines changed

dexplot/_common_plot.py

Lines changed: 252 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
4949
self.sort_values = sort_values
5050
self.groupby_sort = True
5151
self.wrap = wrap
52+
self.figsize = figsize
5253
self.title = title
5354
self.sharex = sharex
5455
self.sharey = sharey
@@ -64,24 +65,14 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
6465
self.x_rot = x_rot
6566
self.y_rot = y_rot
6667

67-
self.validate_args(figsize)
68+
self.validate_args()
6869
self.plot_type = self.get_plot_type()
6970
self.agg_kind = self.get_agg_kind()
7071
self.data = self.set_index()
7172
self.rows, self.cols = self.get_uniques()
7273
self.rows, self.cols = self.get_row_col_order()
7374
self.fig_shape = self.get_fig_shape()
74-
self.user_figsize = figsize is not None
75-
self.figsize = self.get_figsize(figsize)
76-
self.original_rcParams = plt.rcParams.copy()
77-
self.set_rcParams()
78-
self.fig, self.axs = self.create_figure()
79-
self.set_color_cycle()
80-
self.data_for_plots = self.get_data_for_every_plot()
81-
self.final_data = self.get_final_data()
82-
self.style_fig()
83-
self.add_ax_titles()
84-
self.add_fig_title()
75+
8576

8677
def get_data(self, data):
8778
if isinstance(data, pd.Series):
@@ -225,22 +216,11 @@ def get_colors(self, cmap):
225216
raise TypeError('`cmap` must be a string name of a colormap, a matplotlib colormap '
226217
'instance, list, or tuple of colors')
227218

228-
def validate_args(self, figsize):
229-
self.validate_figsize(figsize)
219+
def validate_args(self):
230220
self.validate_plot_args()
231221
self.validate_mpl_args()
232222
self.validate_sort_values()
233223

234-
def validate_figsize(self, figsize):
235-
if isinstance(figsize, (list, tuple)):
236-
if len(figsize) != 2:
237-
raise ValueError('figsize must be a two-item tuple/list')
238-
for val in figsize:
239-
if not isinstance(val, (int, float)):
240-
raise ValueError('Each item in figsize must be an integer or a float')
241-
elif figsize is not None:
242-
raise TypeError('figsize must be a two-item tuple')
243-
244224
def validate_plot_args(self):
245225
if self.orientation not in ('v', 'h'):
246226
raise ValueError('`orientation` must be either "v" or "h".')
@@ -397,25 +377,6 @@ def get_labels(self, labels):
397377
return None, str(labels)
398378
return None, None
399379

400-
def get_figsize(self, figsize):
401-
if figsize:
402-
return figsize
403-
else:
404-
return self.fig_shape[1] * 4, self.fig_shape[0] * 3
405-
406-
def create_figure(self):
407-
fig = plt.Figure(tight_layout=True, dpi=144, figsize=self.figsize)
408-
axs = fig.subplots(*self.fig_shape, sharex=self.sharex, sharey=self.sharey)
409-
if self.fig_shape != (1, 1):
410-
axs = axs.flatten(order='F')
411-
else:
412-
axs = [axs]
413-
return fig, axs
414-
415-
def set_color_cycle(self):
416-
for ax in self.axs:
417-
ax.set_prop_cycle(color=self.colors)
418-
419380
def sort_values_xy(self, x, y):
420381
grp, num = (x, y) if self.orientation == 'v' else (y, x)
421382
if self.sort_values is None:
@@ -522,7 +483,7 @@ def get_final_groups(self, data, split_label, row_label, col_label):
522483
else:
523484
col = self.x or self.y
524485
vals = data[col]
525-
groups.append((vals, split_label, None, row_label, col_label))
486+
groups.append((vals, split_label, self.col, row_label, col_label))
526487
elif self.groupby is not None:
527488
try:
528489
s = data.groupby(self.groupby, sort=self.groupby_sort)[self.agg].agg(self.aggfunc)
@@ -536,18 +497,18 @@ def get_final_groups(self, data, split_label, row_label, col_label):
536497
x, y = s.index.values, s.values
537498
x, y = (x, y) if self.orientation == 'v' else (y, x)
538499
x, y = self.get_correct_data_order(x, y)
539-
groups.append((x, y, split_label, None, row_label, col_label))
500+
groups.append((x, y, split_label, self.groupby, row_label, col_label))
540501
elif self.x is None or self.y is None:
541502
if self.x:
542503
s = data[self.x]
543504
x, y = s.values, s.index.values
544505
x, y = self.get_correct_data_order(x, y)
545-
groups.append((x, y, split_label, None, row_label, col_label))
506+
groups.append((x, y, split_label, self.x, row_label, col_label))
546507
elif self.y:
547508
s = data[self.y]
548509
x, y = s.index.values, s.values
549510
x, y = self.get_correct_data_order(x, y)
550-
groups.append((x, y, split_label, None, row_label, col_label))
511+
groups.append((x, y, split_label, self.y, row_label, col_label))
551512
else:
552513
# wide data
553514
for col in self.get_wide_columns(data):
@@ -563,6 +524,76 @@ def get_final_groups(self, data, split_label, row_label, col_label):
563524
groups.append((x, y, split_label, None, row_label, col_label))
564525
return groups
565526

527+
def get_x_y_plot(self, x, y):
    """Return the positions to plot for ``x`` and ``y``.

    An input whose ``dtype.kind`` is ``'O'`` (object) is replaced by
    ``np.arange(len(values))``; any other input is returned unchanged.
    """
    def as_positions(values):
        # Object arrays get integer positions 0..len-1 in their place.
        return np.arange(len(values)) if values.dtype.kind == 'O' else values

    return as_positions(x), as_positions(y)
534+
535+
def get_distribution_data(self, info):
    """Group distribution values and tick labels by split label.

    ``info`` is an iterable of 5-tuples
    ``(vals, split_label, col_name, row_label, col_label)``.  Two
    ``defaultdict(list)`` objects keyed by ``split_label`` are returned:
    the first collects each ``vals``, the second each ``col_name``.
    The row/column labels are unpacked but not used here.
    """
    data_by_split = defaultdict(list)
    ticklabels_by_split = defaultdict(list)
    for entry in info:
        vals, split_label, col_name, _row_label, _col_label = entry
        data_by_split[split_label].append(vals)
        ticklabels_by_split[split_label].append(col_name)
    return data_by_split, ticklabels_by_split
542+
543+
544+
class MPLCommon(CommonPlot):
545+
546+
def __init__(self, x, y, data, aggfunc, split, row, col,
             x_order, y_order, split_order, row_order, col_order,
             orientation, sort_values, wrap, figsize, title, sharex, sharey,
             xlabel, ylabel, xlim, ylim, xscale, yscale, cmap,
             x_textwrap, y_textwrap, x_rot, y_rot,
             check_numeric=False, kind=None):
    """Run the shared ``CommonPlot`` setup, then build the matplotlib figure.

    All arguments are forwarded unchanged to ``CommonPlot.__init__``.
    Afterwards the figure size is resolved, rcParams are saved and
    overridden, the figure/axes are created and styled, and the per-plot
    data is prepared.
    """
    # Forward the caller's check_numeric / kind instead of hard-coded
    # defaults, which silently discarded whatever the caller passed.
    super().__init__(x, y, data, aggfunc, split, row, col,
                     x_order, y_order, split_order, row_order, col_order,
                     orientation, sort_values, wrap, figsize, title, sharex, sharey,
                     xlabel, ylabel, xlim, ylim, xscale, yscale, cmap,
                     x_textwrap, y_textwrap, x_rot, y_rot,
                     check_numeric=check_numeric, kind=kind)
    # Record whether the user supplied a size *before* get_figsize()
    # replaces a missing one with the grid-derived default.
    self.user_figsize = self.figsize is not None
    self.figsize = self.get_figsize()
    self.original_rcParams = plt.rcParams.copy()
    self.set_rcParams()
    self.fig, self.axs = self.create_figure()
    self.set_color_cycle()
    self.data_for_plots = self.get_data_for_every_plot()
    self.final_data = self.get_final_data()
    self.style_fig()
    self.add_ax_titles()
    self.add_fig_title()

def get_figsize(self):
    """Validate ``self.figsize`` and return the figure size to use.

    Returns
    -------
    The user-supplied ``self.figsize`` when it is a valid two-item
    list/tuple of ints/floats; otherwise, when ``self.figsize`` is None,
    the default ``(fig_shape[1] * 4, fig_shape[0] * 3)``.

    Raises
    ------
    TypeError
        If ``self.figsize`` is neither None nor a list/tuple.
    ValueError
        If it does not have exactly two items, or an item is not an
        int or float.
    """
    figsize = self.figsize
    if figsize is None:
        # No user preference: 4 units per grid column, 3 per grid row.
        return self.fig_shape[1] * 4, self.fig_shape[0] * 3
    if not isinstance(figsize, (list, tuple)):
        raise TypeError('figsize must be a two-item tuple')
    if len(figsize) != 2:
        raise ValueError('figsize must be a two-item tuple/list')
    for val in figsize:
        if not isinstance(val, (int, float)):
            raise ValueError('Each item in figsize must be an integer or a float')
    # A validated user-supplied size is honoured as given.
    return figsize
583+
584+
def create_figure(self):
    """Create the matplotlib figure and its grid of axes.

    Returns ``(figure, axes)``.  For a ``(1, 1)`` grid the single axes
    object is wrapped in a list; otherwise the axes array is flattened
    in column-major (``order='F'``) order.
    """
    figure = plt.Figure(tight_layout=True, dpi=144, figsize=self.figsize)
    axes = figure.subplots(*self.fig_shape, sharex=self.sharex, sharey=self.sharey)
    if self.fig_shape == (1, 1):
        return figure, [axes]
    return figure, axes.flatten(order='F')
592+
593+
def set_color_cycle(self):
    """Set ``self.colors`` as the color property cycle on every axes in ``self.axs``."""
    for axes in self.axs:
        axes.set_prop_cycle(color=self.colors)
596+
566597
def get_final_data(self):
567598
# create list of data for each call to plotting method
568599
final_data = defaultdict(list)
@@ -627,22 +658,6 @@ def set_rcParams(self):
627658
plt.rcParams['font.size'] = 6
628659
plt.rcParams['font.family'] = 'Helvetica'
629660

630-
def get_x_y_plot(self, x, y):
631-
x_plot, y_plot = x, y
632-
if x_plot.dtype.kind == 'O':
633-
x_plot = np.arange(len(x_plot))
634-
if y_plot.dtype.kind == 'O':
635-
y_plot = np.arange(len(y_plot))
636-
return x_plot, y_plot
637-
638-
def get_distribution_data(self, info):
639-
cur_data = defaultdict(list)
640-
cur_ticklabels = defaultdict(list)
641-
for vals, split_label, col_name, row_label, col_label in info:
642-
cur_data[split_label].append(vals)
643-
cur_ticklabels[split_label].append(col_name)
644-
return cur_data, cur_ticklabels
645-
646661
def add_ticklabels(self, labels, ax, delta=0):
647662
ticks = np.arange(len(labels))
648663
ha, va = 'center', 'center'
@@ -700,3 +715,177 @@ def update_fig_size(self, n_splits, n_groups_per_split):
700715

701716
def add_fig_title(self):
702717
self.fig.suptitle(self.title, y=1.02)
718+
719+
720+
import plotly.graph_objects as go
721+
from plotly.subplots import make_subplots
722+
723+
724+
class PlotlyCommon(CommonPlot):
    """``CommonPlot`` subclass that builds a plotly subplot figure."""

    def __init__(self, x, y, data, aggfunc, split, row, col,
                 x_order, y_order, split_order, row_order, col_order,
                 orientation, sort_values, wrap, figsize, title, sharex, sharey,
                 xlabel, ylabel, xlim, ylim, xscale, yscale, cmap,
                 x_textwrap, y_textwrap, x_rot, y_rot,
                 check_numeric=False, kind=None):
        """Run the shared ``CommonPlot`` setup, then prepare data and the figure.

        All arguments are forwarded unchanged to ``CommonPlot.__init__``.
        """
        # Forward the caller's check_numeric / kind instead of hard-coded
        # defaults, which silently discarded whatever the caller passed.
        super().__init__(x, y, data, aggfunc, split, row, col,
                         x_order, y_order, split_order, row_order, col_order,
                         orientation, sort_values, wrap, figsize, title, sharex, sharey,
                         xlabel, ylabel, xlim, ylim, xscale, yscale, cmap,
                         x_textwrap, y_textwrap, x_rot, y_rot,
                         check_numeric=check_numeric, kind=kind)

        # final_data must exist before create_figure(): the subplot titles
        # are derived from it.
        self.data_for_plots = self.get_data_for_every_plot()
        self.final_data = self.get_final_data()
        self.fig = self.create_figure()

    def create_figure(self):
        """Create the plotly figure with one titled subplot per grid cell.

        Returns the figure produced by ``make_subplots`` after the figure
        title and legend title (the ``split`` name) have been set.
        """
        titles = self.get_subplot_titles()
        fig = make_subplots(rows=self.fig_shape[0], cols=self.fig_shape[1], subplot_titles=titles,
                            shared_xaxes=self.sharex, shared_yaxes=self.sharey,
                            horizontal_spacing=.03)
        fig.update_layout(title_text=self.title, legend_title_text=self.split)
        return fig

    def get_final_data(self):
        """Map each subplot location to the list of groups to draw there.

        Returns a ``defaultdict(list)`` keyed by 1-based ``(row, col)``
        tuples, generated in row-major order and paired positionally with
        ``self.data_for_plots``.
        """
        # create list of data for each call to plotting method
        final_data = defaultdict(list)
        locs = []
        for i in range(self.fig_shape[0]):
            for j in range(self.fig_shape[1]):
                locs.append((i + 1, j + 1))

        for (labels, data), loc in zip(self.data_for_plots, locs):
            row_label, col_label = self.get_labels(labels)
            if self.split:
                for grp, data_grp in self.get_ordered_groups(data, self.split_order, 'split'):
                    final_data[loc].extend(self.get_final_groups(data_grp, grp, row_label, col_label))
            else:
                final_data[loc].extend(self.get_final_groups(data, None, row_label, col_label))
        return final_data

    def get_subplot_titles(self):
        """Build one title string per entry of ``self.final_data``.

        The row and column labels are read from the last two items of the
        first group of each entry.  When both are present they are joined
        with ``' - '``; otherwise whichever exists is used (or an empty
        string).  Each title is wrapped at 30 characters.
        """
        titles = []
        for info in self.final_data.values():
            row_label, col_label = info[0][-2:]
            row_label = '' if row_label is None else str(row_label)
            col_label = '' if col_label is None else str(col_label)
            if row_label and col_label:
                title = row_label + ' - ' + col_label
            else:
                title = row_label or col_label
            titles.append(textwrap.fill(title, 30))
        return titles
785+
786+
787+
class CountCommon(CommonPlot):
    """``CommonPlot`` subclass adding value-count tabulation per subplot."""

    def get_count_dict(self, normalize):
        """Return a dict of value-count DataFrames, one per ``final_data`` key.

        For every entry of ``self.final_data`` the value counts of each
        group are concatenated column-wise (columns named by split label)
        and optionally normalized.

        Parameters
        ----------
        normalize : bool, str, list, tuple or object with ``tolist``
            ``True`` divides every table by the grand total of all counts;
            ``False`` leaves raw counts.  A column name (or collection of
            names) drawn from the counted column, ``split``, ``row`` and
            ``col`` divides by the group sizes over those columns.

        Raises
        ------
        ValueError
            If ``normalize`` is of an unsupported type, is a string that is
            not one of the recognised columns, or lists an unrecognised
            column.
        """
        # `val` was referenced below without ever being defined (NameError).
        # NOTE(review): assumed to be the counted column, i.e. whichever of
        # x / y was supplied -- the same `self.x or self.y` expression that
        # get_final_groups uses; confirm against the count() caller.
        val = self.x or self.y
        count_dict = {}

        if isinstance(normalize, str):
            if normalize in (val, self.split, self.row, self.col):
                normalize = [normalize]

        if isinstance(normalize, tuple):
            normalize = list(normalize)
        elif hasattr(normalize, 'tolist'):
            normalize = normalize.tolist()
        elif not isinstance(normalize, (bool, list)):
            # Also reached by a string that matched none of the columns.
            raise ValueError('`normalize` must either be `True`/`False`, one of the columns passed '
                             'to `val`, `split`, `row` or `col`, or a list of '
                             'those columns')
        normalize_kind = None
        if isinstance(normalize, list):
            # Separate the grid (row/col) columns from the val/split columns.
            row_col = []
            val_split = []
            for col in normalize:
                if col in (self.row, self.col):
                    row_col.append(col)
                elif col in (val, self.split):
                    val_split.append(col)
                else:
                    raise ValueError('Columns passed to `normalize` must be the same as '
                                     ' `val`, `split`, `row` or `col`.')

            if row_col:
                all_counts = {}
                for grp, data in self.data.groupby(row_col):
                    # Keys are stringified for the lookup by label below.
                    if len(row_col) == 1:
                        grp = str(grp)
                    else:
                        grp = tuple(str(g) for g in grp)

                    if val_split:
                        normalize_kind = 'all'
                        all_counts[grp] = data.groupby(val_split).size()
                    else:
                        normalize_kind = 'grid'
                        all_counts[grp] = len(data)
            else:
                normalize_kind = 'single'
                all_counts = self.data.groupby(val_split).size()

        n = 0
        for key, info in self.final_data.items():
            columns = []
            vcs = []
            for vals, split_label, col_name, row_label, col_label in info:
                vcs.append(vals.value_counts())
                columns.append(split_label)

            # vals / row_label / col_label deliberately keep the values of
            # the last group in `info` after the loop.
            df = pd.concat(vcs, axis=1)
            df.columns = columns
            df.index.name = vals.name
            if normalize_kind == 'single':
                if len(val_split) == 2:
                    df = df / all_counts.unstack(self.split)
                elif df.index.name == all_counts.index.name:
                    df = df.div(all_counts, axis=0)
                else:
                    df = df / all_counts
            elif normalize_kind in ('grid', 'all'):
                # Rebuild the all_counts key for this subplot from its labels.
                grp = []
                for col in normalize:
                    if col == self.row:
                        grp.append(row_label)
                    if col == self.col:
                        grp.append(col_label)

                if len(grp) == 1:
                    grp = grp[0]
                else:
                    grp = tuple(grp)
                grp_val = all_counts[grp]

                if normalize_kind == 'grid':
                    df = df / grp_val
                elif len(val_split) == 2:
                    df = df / grp_val.unstack(self.split)
                elif df.index.name == grp_val.index.name:
                    df = df.div(grp_val, axis=0)
                else:
                    df = df / grp_val

            else:
                # No column-based normalization: accumulate the grand total.
                n += df.sum().sum()
            count_dict[key] = df

        if normalize is True:
            count_dict = {key: df / n for key, df in count_dict.items()}

        return count_dict
884+
885+
886+
class MPLCount(CountCommon, MPLCommon):
    """Combines ``CountCommon`` with ``MPLCommon``; adds no members of its own."""
888+
889+
890+
class PlotlyCount(CountCommon, PlotlyCommon):
    """Combines ``CountCommon`` with ``PlotlyCommon``; adds no members of its own."""

0 commit comments

Comments
 (0)