Skip to content

Commit 895e4c2

Browse files
committed
updates
1 parent b67d76e commit 895e4c2

File tree

3 files changed

+376
-178
lines changed

3 files changed

+376
-178
lines changed

dexplot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from ._utils import load_dataset
33
from . import colors
44

5-
__version__ = '0.0.10'
5+
__version__ = '0.1.0'

dexplot/_common_plot.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import matplotlib.pyplot as plt
99
from matplotlib import ticker
1010
from matplotlib.colors import Colormap
11-
from scipy import stats
1211

1312

1413
NONETYPE = type(None)
@@ -20,22 +19,22 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
2019
x_order, y_order, split_order, row_order, col_order,
2120
orientation, sort_values, wrap, figsize, title, sharex, sharey,
2221
xlabel, ylabel, xlim, ylim, xscale, yscale, cmap,
23-
x_textwrap, y_textwrap, check_numeric=False):
22+
x_textwrap, y_textwrap, check_numeric=False, kind=None):
2423

2524
self.used_columns = set()
2625
self.data = self.get_data(data)
2726
self.x = self.get_col(x)
2827
self.y = self.get_col(y)
2928
self.validate_x_y()
3029
self.orientation = orientation
31-
self.aggfunc = aggfunc
30+
self.aggfunc = self.get_aggfunc(aggfunc)
3231
self.groupby = self.get_groupby()
3332
self.split = self.get_col(split)
3433
self.row = self.get_col(row)
3534
self.col = self.get_col(col)
3635

3736
self.agg = self.set_agg()
38-
self.make_groups_categorical()
37+
self.make_groups_categorical(kind)
3938
self.validate_numeric(check_numeric)
4039

4140
self.x_order = self.validate_order(x_order, 'x')
@@ -79,10 +78,14 @@ def __init__(self, x, y, data, aggfunc, split, row, col,
7978
self.final_data = self.get_final_data()
8079
self.style_fig()
8180
self.add_ax_titles()
81+
self.add_fig_title()
8282

8383
def get_data(self, data):
84+
if isinstance(data, pd.Series):
85+
return data.to_frame()
86+
8487
if not isinstance(data, pd.DataFrame):
85-
raise TypeError('`data` must be a pandas DataFrame')
88+
raise TypeError('`data` must be a pandas DataFrame or Series')
8689
elif len(data) == 0:
8790
raise ValueError('DataFrame contains no data')
8891
return data.copy()
@@ -104,6 +107,13 @@ def validate_x_y(self):
104107
if self.x == self.y and self.x is not None and self.y is not None:
105108
raise ValueError('`x` and `y` cannot be the same column name')
106109

110+
def get_aggfunc(self, aggfunc):
111+
if aggfunc == 'countna':
112+
return lambda x: x.isna().sum()
113+
if aggfunc == 'percna':
114+
return lambda x: x.isna().mean()
115+
return aggfunc
116+
107117
def get_groupby(self):
108118
if self.x is None or self.y is None or self.aggfunc is None:
109119
return
@@ -142,12 +152,16 @@ def filter_data(self):
142152
if name and self.data[name].dtype.name == 'category':
143153
self.data[name].cat.remove_unused_categories(inplace=True)
144154

145-
def make_groups_categorical(self):
155+
def make_groups_categorical(self, kind):
146156
category_cols = [self.groupby, self.split, self.row, self.col]
147157
for col in category_cols:
148158
if col:
149159
if self.data[col].dtype.name != 'category':
150160
self.data[col] = self.data[col].astype('category')
161+
if kind == 'count':
162+
col = self.x or self.y
163+
if self.data[col].dtype.name != 'category':
164+
self.data[col] = self.data[col].astype('category')
151165

152166
def validate_numeric(self, check_numeric):
153167
if check_numeric:
@@ -348,6 +362,7 @@ def get_fig_shape(self):
348362
return nrows, ncols
349363

350364
def get_data_for_every_plot(self):
365+
# TODO: catch KeyError for groups that don't exist
351366
rows, cols = self.get_row_col_order()
352367
if self.plot_type == 'row_only':
353368
return [(row, self.data.loc[row]) for row in rows]
@@ -362,7 +377,7 @@ def get_data_for_every_plot(self):
362377
with warnings.catch_warnings():
363378
warnings.simplefilter("ignore")
364379
data = self.data.loc[group]
365-
except KeyError:
380+
except (KeyError, TypeError):
366381
data = self.data.iloc[:0]
367382
groups.append((group, data))
368383
return groups
@@ -423,7 +438,7 @@ def get_order(self, arr, vals):
423438

424439
def reverse_order(self, order):
425440
cond1 = order == 'desc' and self.orientation == 'v'
426-
cond2 = order == 'asc' and self.orientation == 'h'
441+
cond2 = order in ('asc', None) and self.orientation == 'h'
427442
return cond1 or cond2
428443

429444
def order_xy(self, x, y):
@@ -471,7 +486,8 @@ def get_ordered_groups(self, data, specific_order, kind):
471486
order = []
472487
groups = []
473488
sort = specific_order is not None
474-
for grp, data_grp in data.groupby(getattr(self, kind), sort=sort):
489+
# TODO: Need to decide on the defaults for x_order, y_order, etc. - either None or 'asc'
490+
for grp, data_grp in data.groupby(getattr(self, kind), sort=True):
475491
order.append((grp, data_grp))
476492
groups.append(grp)
477493

@@ -535,12 +551,13 @@ def get_final_groups(self, data, split_label, row_label, col_label):
535551
s = data[col]
536552
x, y = s.index.values, s.values
537553
x, y = self.get_correct_data_order(x, y)
538-
groups.append((x, y, split_label, col, row_label, col_label))
554+
x, y = (x, y) if self.orientation == 'v' else (y, x)
555+
groups.append((x, y, col, None, row_label, col_label))
539556
else:
540557
# simple raw plot - make sure to warn when lots of data for bar/box/hist
541558
# one graph per row - OK for scatterplots and line plots
542559
x, y = self.get_correct_data_order(data[self.x], data[self.y])
543-
groups.append((x, y, None, None, row_label, col_label))
560+
groups.append((x, y, split_label, None, row_label, col_label))
544561
return groups
545562

546563
def get_final_data(self):
@@ -635,8 +652,8 @@ def add_ticklabels(self, labels, ax, delta=0):
635652
ax.set_yticks(ticks - delta)
636653
ax.set_yticklabels(labels)
637654

638-
def add_legend(self, handles=None, labels=None):
639-
if self.split:
655+
def add_legend(self, label=None, handles=None, labels=None):
656+
if label is not None:
640657
if handles is None:
641658
handles, labels = self.axs[0].get_legend_handles_labels()
642659
ncol = len(labels) // 8 + 1
@@ -664,4 +681,7 @@ def update_fig_size(self, n_splits, n_groups_per_split):
664681
height = new_size * .8 * self.fig_shape[0]
665682
width = width * self.fig_shape[1]
666683
width, height = min(width, 25), min(height, 25)
667-
self.fig.set_size_inches(width, height)
684+
self.fig.set_size_inches(width, height)
685+
686+
def add_fig_title(self):
687+
self.fig.suptitle(self.title, y=1.02)

0 commit comments

Comments
 (0)