Skip to content

Commit bfe77f9

Browse files
committed
Merge pull request yhat#325 from arnfred/bar
Reworked bar plot
2 parents f25156f + 7cf9048 commit bfe77f9

21 files changed

+240
-72
lines changed

ggplot/geoms/geom.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
from copy import deepcopy
44

55
import pandas as pd
6+
import numpy as np
67
from matplotlib.cbook import iterable
8+
from ggplot.utils import is_string
79

810
import ggplot.stats
911
from ggplot.utils import is_scalar_or_string
@@ -199,7 +201,7 @@ def _find_aes_and_data(self, args, kwargs):
199201

200202
for arg in args:
201203
if isinstance(arg, aes) and passed_aes:
202-
raise Execption(aes_err)
204+
raise Exception(aes_err)
203205
if isinstance(arg, aes):
204206
passed_aes = arg
205207
elif isinstance(arg, pd.DataFrame):
@@ -255,6 +257,7 @@ def _calculate_stats(self, data):
255257
data : dataframe
256258
"""
257259
self._stat._verify_aesthetics(data)
260+
self._stat._calculate_global(data)
258261
# In most cases 'x' and 'y' mappings do not and
259262
# should not influence the grouping. If this is
260263
# not the desired behaviour then the groups
@@ -274,9 +277,6 @@ def _calculate_stats(self, data):
274277
else:
275278
new_data = self._stat._calculate(data)
276279

277-
# some geoms expect a sorted x domain
278-
if 'x' in new_data:
279-
new_data.sort(columns=('x'), inplace=True)
280280
return new_data
281281

282282
def _create_aes_with_mpl_names(self):
@@ -338,3 +338,33 @@ def _get_unit_grouped_data(self, data, units):
338338
_data = data.to_dict('list')
339339
out.append(_data)
340340
return out
341+
342+
343+
def sort_by_x(self, pinfo):
    """
    Sort the list-like entries of pinfo according to pinfo['x'].

    This is useful for geoms that expect the x-values to come
    in sorted order.

    Parameters
    ----------
    pinfo : dict
        Plot information. Entries that are non-string iterables are
        reordered together; every other entry is left untouched.

    Returns
    -------
    dict
        The same pinfo object. It is returned unchanged when there is
        no list-like 'x' entry to sort by, or when the values cannot
        be sorted.
    """
    # Names of the list-like aesthetics; scalars and strings stay as is
    list_keys = [key for key in pinfo
                 if not is_string(pinfo[key]) and iterable(pinfo[key])]

    # Nothing to sort by: leave pinfo exactly as it was given
    if 'x' not in list_keys:
        return pinfo

    # Sort numerically if all items can be cast.
    # The builtin float is used because the np.float alias no longer
    # exists in current NumPy releases.
    x = pinfo['x']
    try:
        sort_keys = [float(value) for value in x]
    except (ValueError, TypeError):
        sort_keys = x

    # Build every reordered list before touching pinfo, so that a
    # failure (unsortable values, mismatched lengths) cannot leave
    # pinfo half-updated or stripped of its list entries
    try:
        idx = np.argsort(sort_keys)
        reordered = {key: [pinfo[key][i] for i in idx]
                     for key in list_keys}
    except (TypeError, ValueError, IndexError):
        return pinfo

    pinfo.update(reordered)
    return pinfo
370+

ggplot/geoms/geom_area.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import (absolute_import, division, print_function,
22
unicode_literals)
3+
34
from .geom import geom
45

56

@@ -14,4 +15,5 @@ class geom_area(geom):
1415
_units = { 'alpha', 'edgecolor', 'facecolor', 'linestyle', 'linewidth'}
1516

1617
def _plot_unit(self, pinfo, ax):
18+
pinfo = self.sort_by_x(pinfo)
1719
ax.fill_between(**pinfo)

ggplot/geoms/geom_bar.py

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
class geom_bar(geom):
1313
DEFAULT_AES = {'alpha': None, 'color': None, 'fill': '#333333',
14-
'linetype': 'solid', 'size': 1.0, 'weight': None, 'y': None}
14+
'linetype': 'solid', 'size': 1.0, 'weight': None, 'y': None, 'width' : None}
1515
REQUIRED_AES = {'x'}
1616
DEFAULT_PARAMS = {'stat': 'bin', 'position': 'stack'}
1717

@@ -20,43 +20,35 @@ class geom_bar(geom):
2020
'fill': 'color', 'color': 'edgecolor'}
2121
# NOTE: Currently, geom_bar does not support mapping
2222
# to alpha and linestyle. TODO: raise exception
23-
_units = {'alpha', 'linestyle', 'linewidth'}
23+
_units = {'edgecolor', 'color', 'alpha', 'linestyle', 'linewidth'}
2424

25-
def _sort_list_types_by_x(self, pinfo):
26-
"""
27-
Sort the lists in pinfo according to pinfo['x']
28-
"""
29-
# Remove list types from pinfo
30-
_d = {}
31-
for k in list(pinfo.keys()):
32-
if not is_string(pinfo[k]) and cbook.iterable(pinfo[k]):
33-
_d[k] = pinfo.pop(k)
3425

35-
# Sort numerically if all items can be cast
36-
try:
37-
x = list(map(np.float, _d['x']))
38-
except ValueError:
39-
x = _d['x']
40-
idx = np.argsort(x)
26+
def __init__(self, *args, **kwargs):
    """
    Set up the geom and the state used while drawing bars.

    All arguments are forwarded unchanged to the parent initializer.
    """
    super(geom_bar, self).__init__(*args, **kwargs)
    # Axes used by the most recent _plot_unit call; _plot_unit compares
    # it with the current axes to decide whether to reset `bottom`
    self.ax = None
    # Accumulated bar heights that _plot_unit passes to ax.bar as the
    # `bottom` argument; None until the first unit is drawn
    self.bottom = None
4131

42-
# Put sorted lists back in pinfo
43-
for key in _d:
44-
pinfo[key] = [_d[key][i] for i in idx]
45-
46-
return pinfo
4732

4833
def _plot_unit(self, pinfo, ax):
4934
categorical = is_categorical(pinfo['x'])
50-
# If x is not numeric, the bins are sorted acc. to x
51-
# so the list type aesthetics must be sorted too
52-
if categorical:
53-
pinfo = self._sort_list_types_by_x(pinfo)
5435

5536
pinfo.pop('weight')
5637
x = pinfo.pop('x')
57-
width = np.array(pinfo.pop('width'))
58-
heights = pinfo.pop('y')
59-
labels = x
38+
width_elem = pinfo.pop('width')
39+
# If width is unspecified, default is an array of 1's
40+
if width_elem == None:
41+
width = np.ones(len(x))
42+
else :
43+
width = np.array(width_elem)
44+
45+
# Make sure bottom is initialized and get heights. If we are working on
46+
# a new plot (using facet_wrap or grid), then reset bottom
47+
_reset = self.bottom == None or (self.ax != None and self.ax != ax)
48+
self.bottom = np.zeros(len(x)) if _reset else self.bottom
49+
self.ax = ax
50+
heights = np.array(pinfo.pop('y'))
51+
6052

6153
# layout and spacing
6254
#
@@ -79,14 +71,14 @@ def _plot_unit(self, pinfo, ax):
7971
_spacing_factor = 0.105 # of the bin width
8072
_breaks = np.append([0], width)
8173
left = np.cumsum(_breaks[:-1])
82-
8374
_sep = width[0] * _spacing_factor
8475
left = left + _left_gap + [_sep * i for i in range(len(left))]
85-
86-
87-
ax.bar(left, heights, width, **pinfo)
76+
ax.bar(left, heights, width, bottom=self.bottom, **pinfo)
8877
ax.autoscale()
8978

9079
if categorical:
9180
ax.set_xticks(left+width/2)
9281
ax.set_xticklabels(x)
82+
83+
# Update bottom positions
84+
self.bottom = heights + self.bottom

ggplot/geoms/geom_histogram.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,4 @@
22
unicode_literals)
33
from .geom_bar import geom_bar
44

5-
class geom_histogram(geom_bar):
6-
pass
5+
geom_histogram = geom_bar

ggplot/geoms/geom_line.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def _plot_unit(self, pinfo, ax):
2828
sys.stderr.write(msg)
2929
self._warning_printed = True
3030

31+
pinfo = self.sort_by_x(pinfo)
3132
x = pinfo.pop('x')
3233
y = pinfo.pop('y')
3334
ax.plot(x, y, **pinfo)

ggplot/stats/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import (absolute_import, division, print_function,
2-
unicode_literals)
2+
unicode_literals)
33

44
from .stat_abline import stat_abline
55
from .stat_bin import stat_bin
@@ -10,6 +10,7 @@
1010
from .stat_identity import stat_identity
1111
from .stat_smooth import stat_smooth
1212
from .stat_vline import stat_vline
13+
from .stat_bar import stat_bar
1314

1415
__all__ = ['stat_abline', 'stat_bin', 'stat_bin2d', 'stat_density',
1516
'stat_function', 'stat_hline', 'stat_identity',

ggplot/stats/stat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,14 @@ def _print_warning(self, message):
4646
sys.stderr.write(message)
4747
self._warnings_printed.add(message)
4848

49+
# For some stats we need to calculate something from the entire set of data
50+
# before we work with the groups. An example is stat_bin, where we need to
51+
# know the max and min of the x-axis globally. If we don't we end up with
52+
# groups that are binned based on only the group x-axis leading to
53+
# different bin-sizes.
54+
def _calculate_global(self, data):
55+
pass
56+
4957
def _calculate(self, data):
5058
msg = "{} should implement this method."
5159
raise NotImplementedError(

ggplot/stats/stat_bar.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from __future__ import (absolute_import, division, print_function,
2+
unicode_literals)
3+
import pandas as pd
4+
5+
from .stat import stat
6+
7+
_MSG_LABELS = """There are more than 30 unique values mapped to x.
8+
If you want a histogram instead, use 'geom_histogram()'.
9+
"""
10+
11+
class stat_bar(stat):
    """
    Stat that lays the given data out along a fixed list of x labels.

    The labels are determined once for the whole dataset in
    `_calculate_global`; `_calculate` then rebuilds each group's data
    with one row per label.
    """
    REQUIRED_AES = {'x', 'y'}
    DEFAULT_PARAMS = {'geom': 'bar', 'position': 'stack',
                      'width': 0.9, 'drop': False,
                      'origin': None, 'labels': None}

    def _calculate(self, data):
        """
        Reorder the data according to the labels.

        Parameters
        ----------
        data : dataframe
            Must contain an 'x' column. Relies on `self.labels`,
            which `_calculate_global` sets.

        Returns
        -------
        dataframe
            One row per label. A label that does not occur in
            data['x'] gets 0 for 'y' and the first value of the
            column for every other column.
        """
        labels = self.labels
        new_data = pd.DataFrame()
        new_data["x"] = labels
        # Walk the columns in their original order so the layout of
        # the result is deterministic
        for column in data.columns:
            if column == 'x':
                continue
            column_dict = dict(zip(data["x"], data[column]))
            default = 0 if column == "y" else data[column].values[0]
            new_data[column] = [column_dict.get(val, default)
                                for val in labels]
        return new_data

    def _calculate_global(self, data):
        """
        Determine the x labels from the entire set of data.

        Uses the 'labels' parameter when it is given, otherwise the
        sorted unique values of data['x']. The result is stored in
        `self.labels`.
        """
        labels = self.params['labels']
        # Identity check: `== None` compares elementwise when labels
        # is an array
        if labels is None:
            labels = sorted(set(data['x'].values))
        # For a lot of labels, put out a warning
        if len(labels) > 30:
            self._print_warning(_MSG_LABELS)
        self.labels = labels

ggplot/stats/stat_bin.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
_MSG_YVALUE = """A variable was mapped to y.
1414
stat_bin sets the y value to the count of cases in each group.
1515
The mapping to y was ignored.
16-
If you want y to represent values in the data, use stat="identity".
16+
If you want y to represent values in the data, use stat="bar".
1717
"""
1818

1919
_MSG_BINWIDTH = """stat_bin: binwidth defaulted to range/30.
@@ -23,16 +23,46 @@
2323

2424
class stat_bin(stat):
2525
REQUIRED_AES = {'x'}
26-
DEFAULT_PARAMS = {'geom': 'bar', 'position': 'stack',
26+
DEFAULT_PARAMS = {'geom': 'histogram', 'position': 'stack',
2727
'width': 0.9, 'drop': False, 'right': False,
28-
'binwidth': None, 'origin': None, 'breaks': None}
28+
'binwidth': None, 'origin': None, 'breaks': None,
29+
'labels': None}
2930
CREATES = {'y', 'width'}
3031

32+
33+
def _calculate_global(self, data):
    """
    Compute the labels or the breaks from the entire set of data.

    For categorical x this sets `self.labels` (the 'labels' parameter
    if given, otherwise the sorted unique x values) and `self.length`.
    For other x this sets `self.breaks`: the 'breaks' parameter when
    one is supplied, otherwise bin edges computed with pandas.cut,
    in which case `self.length` is set as well.

    Parameters
    ----------
    data : dataframe
        Must contain an 'x' column.

    Raises
    ------
    GgplotError
        If x is neither categorical nor numerical.
    """
    binwidth = self.params['binwidth']
    self.breaks = self.params['breaks']
    right = self.params['right']
    x = data['x'].values

    # For categorical data we set labels and x-vals
    if is_categorical(x):
        labels = self.params['labels']
        # Identity check: `== None` compares elementwise when labels
        # is an array
        if labels is None:
            labels = sorted(set(x))
        self.labels = labels
        self.length = len(self.labels)
        return

    # Breaks that were passed in are used as they are. Testing for
    # None and emptiness explicitly avoids taking the truth value of
    # an array
    if self.breaks is not None and np.size(self.breaks) > 0:
        return

    # For non-categorical data we set breaks
    # Check that x is numerical
    if not cbook.is_numlike(x[0]):
        raise GgplotError("Cannot recognise the type of x")
    if binwidth is None:
        _bin_count = 30
        self._print_warning(_MSG_BINWIDTH)
    else:
        # Divide the range by the bin width before rounding up, so the
        # count is an integer; pandas.cut needs an int and at least
        # one bin
        _bin_count = max(1, int(np.ceil(np.ptp(x) / binwidth)))
    _, self.breaks = pd.cut(x, bins=_bin_count, labels=False,
                            right=right, retbins=True)
    self.length = len(self.breaks)
61+
62+
3163
def _calculate(self, data):
3264
x = data.pop('x')
33-
breaks = self.params['breaks']
3465
right = self.params['right']
35-
binwidth = self.params['binwidth']
3666

3767
# y values are not needed
3868
try:
@@ -50,26 +80,16 @@ def _calculate(self, data):
5080
else:
5181
weights = make_iterable_ntimes(weights, len(x))
5282

53-
categorical = is_categorical(x.values)
54-
if categorical:
83+
if is_categorical(x.values):
5584
x_assignments = x
56-
x = sorted(set(x))
57-
width = make_iterable_ntimes(self.params['width'], len(x))
85+
x = self.labels
86+
width = make_iterable_ntimes(self.params['width'], self.length)
5887
elif cbook.is_numlike(x.iloc[0]):
59-
if breaks is None and binwidth is None:
60-
_bin_count = 30
61-
self._print_warning(_MSG_BINWIDTH)
62-
if binwidth:
63-
_bin_count = int(np.ceil(np.ptp(x))) / binwidth
64-
65-
# Breaks have a higher precedence and,
66-
# pandas accepts either the breaks or the number of bins
67-
_bins_info = breaks or _bin_count
68-
x_assignments, breaks = pd.cut(x, bins=_bins_info, labels=False,
69-
right=right, retbins=True)
70-
width = np.diff(breaks)
71-
x = [breaks[i] + width[i] / 2
72-
for i in range(len(breaks)-1)]
88+
x_assignments = pd.cut(x, bins=self.breaks, labels=False,
89+
right=right)
90+
width = np.diff(self.breaks)
91+
x = [self.breaks[i] + width[i] / 2
92+
for i in range(len(self.breaks)-1)]
7393
else:
7494
raise GgplotError("Cannot recognise the type of x")
7595

@@ -86,13 +106,14 @@ def _calculate(self, data):
86106
# For numerical x values, empty bins have no value
87107
# in the computed frequency table. We need to add the zeros and
88108
# since frequency table is a Series object, we need to keep it ordered
89-
if len(_wfreq_table) < len(x):
90-
empty_bins = set(range(len(x))) - set(x_assignments)
91-
_wfreq_table = _wfreq_table.to_dict()
92-
for _b in empty_bins:
93-
_wfreq_table[_b] = 0
94-
_wfreq_table = pd.Series(_wfreq_table)
95-
_wfreq_table = _wfreq_table.sort_index()
109+
try:
110+
empty_bins = set(self.labels) - set(x_assignments)
111+
except:
112+
empty_bins = set(range(len(width))) - set(x_assignments)
113+
_wfreq_table = _wfreq_table.to_dict()
114+
for _b in empty_bins:
115+
_wfreq_table[_b] = 0
116+
_wfreq_table = pd.Series(_wfreq_table).sort_index()
96117

97118
y = list(_wfreq_table)
98119
new_data = pd.DataFrame({'x': x, 'y': y, 'width': width})

ggplot/tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def teardown_package():
3131
'ggplot.tests.test_stat',
3232
'ggplot.tests.test_stat_calculate_methods',
3333
'ggplot.tests.test_geom_rect',
34+
'ggplot.tests.test_geom_bar',
3435
'ggplot.tests.test_qplot',
3536
'ggplot.tests.test_geom_lines',
3637
'ggplot.tests.test_faceting',

0 commit comments

Comments
 (0)