Skip to content

ENH: add figsize argument to DataFrame and Series hist methods #3842

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 12, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ pandas 0.11.1
spurious plots from showing up.
- Added Faq section on repr display options, to help users customize their setup.
- ``where`` operations that result in block splitting are much faster (GH3733_)
- Series and DataFrame hist methods now take a ``figsize`` argument (GH3834_)

**API Changes**

Expand Down Expand Up @@ -312,6 +313,8 @@ pandas 0.11.1
.. _GH3726: https://github.com/pydata/pandas/issues/3726
.. _GH3795: https://github.com/pydata/pandas/issues/3795
.. _GH3814: https://github.com/pydata/pandas/issues/3814
.. _GH3834: https://github.com/pydata/pandas/issues/3834


pandas 0.11.0
=============
Expand Down
3 changes: 3 additions & 0 deletions doc/source/v0.11.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ Enhancements

dff.groupby('B').filter(lambda x: len(x) > 2, dropna=False)

- Series and DataFrame hist methods now take a ``figsize`` argument (GH3834_)


Bug Fixes
~~~~~~~~~
Expand Down Expand Up @@ -396,3 +398,4 @@ on GitHub for a complete list.
.. _GH3741: https://github.com/pydata/pandas/issues/3741
.. _GH3726: https://github.com/pydata/pandas/issues/3726
.. _GH3425: https://github.com/pydata/pandas/issues/3425
.. _GH3834: https://github.com/pydata/pandas/issues/3834
58 changes: 35 additions & 23 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pandas import Series, DataFrame, MultiIndex, PeriodIndex, date_range
import pandas.util.testing as tm
from pandas.util.testing import ensure_clean
from pandas.core.config import set_option,get_option,config_prefix
from pandas.core.config import set_option

import numpy as np

Expand All @@ -28,11 +28,6 @@ class TestSeriesPlots(unittest.TestCase):

@classmethod
def setUpClass(cls):
import sys

# if 'IPython' in sys.modules:
# raise nose.SkipTest

try:
import matplotlib as mpl
mpl.use('Agg', warn=False)
Expand Down Expand Up @@ -150,9 +145,16 @@ def test_irregular_datetime(self):
def test_hist(self):
_check_plot_works(self.ts.hist)
_check_plot_works(self.ts.hist, grid=False)

_check_plot_works(self.ts.hist, figsize=(8, 10))
_check_plot_works(self.ts.hist, by=self.ts.index.month)

def test_plot_fails_when_ax_differs_from_figure(self):
from pylab import figure
fig1 = figure()
fig2 = figure()
ax1 = fig1.add_subplot(111)
self.assertRaises(AssertionError, self.ts.hist, ax=ax1, figure=fig2)

@slow
def test_kde(self):
_skip_if_no_scipy()
Expand Down Expand Up @@ -258,7 +260,8 @@ def test_plot(self):
(u'\u03b4', 6),
(u'\u03b4', 7)], names=['i0', 'i1'])
columns = MultiIndex.from_tuples([('bar', u'\u0394'),
('bar', u'\u0395')], names=['c0', 'c1'])
('bar', u'\u0395')], names=['c0',
'c1'])
df = DataFrame(np.random.randint(0, 10, (8, 2)),
columns=columns,
index=index)
Expand All @@ -269,9 +272,9 @@ def test_nonnumeric_exclude(self):
import matplotlib.pyplot as plt
plt.close('all')

df = DataFrame({'A': ["x", "y", "z"], 'B': [1,2,3]})
df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]})
ax = df.plot()
self.assert_(len(ax.get_lines()) == 1) #B was plotted
self.assert_(len(ax.get_lines()) == 1) # B was plotted

@slow
def test_label(self):
Expand Down Expand Up @@ -434,21 +437,24 @@ def test_bar_center(self):
ax = df.plot(kind='bar', grid=True)
self.assertEqual(ax.xaxis.get_ticklocs()[0],
ax.patches[0].get_x() + ax.patches[0].get_width())

@slow
def test_bar_log(self):
# GH3254, GH3298 matplotlib/matplotlib#1882, #1892
# regressions in 1.2.1

df = DataFrame({'A': [3] * 5, 'B': range(1,6)}, index=range(5))
ax = df.plot(kind='bar', grid=True,log=True)
self.assertEqual(ax.yaxis.get_ticklocs()[0],1.0)
df = DataFrame({'A': [3] * 5, 'B': range(1, 6)}, index=range(5))
ax = df.plot(kind='bar', grid=True, log=True)
self.assertEqual(ax.yaxis.get_ticklocs()[0], 1.0)

p1 = Series([200,500]).plot(log=True,kind='bar')
p2 = DataFrame([Series([200,300]),Series([300,500])]).plot(log=True,kind='bar',subplots=True)
p1 = Series([200, 500]).plot(log=True, kind='bar')
p2 = DataFrame([Series([200, 300]),
Series([300, 500])]).plot(log=True, kind='bar',
subplots=True)

(p1.yaxis.get_ticklocs() == np.array([ 0.625, 1.625]))
(p2[0].yaxis.get_ticklocs() == np.array([ 1., 10., 100., 1000.])).all()
(p2[1].yaxis.get_ticklocs() == np.array([ 1., 10., 100., 1000.])).all()
(p1.yaxis.get_ticklocs() == np.array([0.625, 1.625]))
(p2[0].yaxis.get_ticklocs() == np.array([1., 10., 100., 1000.])).all()
(p2[1].yaxis.get_ticklocs() == np.array([1., 10., 100., 1000.])).all()

@slow
def test_boxplot(self):
Expand Down Expand Up @@ -508,6 +514,9 @@ def test_hist(self):
# make sure sharex, sharey is handled
_check_plot_works(df.hist, sharex=True, sharey=True)

# handle figsize arg
_check_plot_works(df.hist, figsize=(8, 10))

# make sure xlabelsize and xrot are handled
ser = df[0]
xf, yf = 20, 20
Expand Down Expand Up @@ -727,6 +736,7 @@ def test_invalid_kind(self):
df = DataFrame(np.random.randn(10, 2))
self.assertRaises(ValueError, df.plot, kind='aasdf')


class TestDataFrameGroupByPlots(unittest.TestCase):

@classmethod
Expand Down Expand Up @@ -786,10 +796,10 @@ def test_time_series_plot_color_with_empty_kwargs(self):

plt.close('all')
for i in range(3):
ax = Series(np.arange(12) + 1, index=date_range(
'1/1/2000', periods=12)).plot()
ax = Series(np.arange(12) + 1, index=date_range('1/1/2000',
periods=12)).plot()

line_colors = [ l.get_color() for l in ax.get_lines() ]
line_colors = [l.get_color() for l in ax.get_lines()]
self.assert_(line_colors == ['b', 'g', 'r'])

@slow
Expand Down Expand Up @@ -829,7 +839,6 @@ def test_grouped_hist(self):
self.assertRaises(AttributeError, plotting.grouped_hist, df.A,
by=df.C, foo='bar')


def test_option_mpl_style(self):
# just a sanity check
try:
Expand All @@ -845,14 +854,15 @@ def test_option_mpl_style(self):
except ValueError:
pass


def _check_plot_works(f, *args, **kwargs):
import matplotlib.pyplot as plt

fig = plt.gcf()
plt.clf()
ax = fig.add_subplot(211)
ret = f(*args, **kwargs)
assert(ret is not None) # do something more intelligent
assert ret is not None # do something more intelligent

ax = fig.add_subplot(212)
try:
Expand All @@ -865,10 +875,12 @@ def _check_plot_works(f, *args, **kwargs):
with ensure_clean() as path:
plt.savefig(path)


def curpath():
pth, _ = os.path.split(os.path.abspath(__file__))
return pth


if __name__ == '__main__':
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
exit=False)
42 changes: 26 additions & 16 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,9 +658,9 @@ def r(h):
return ax


def grouped_hist(data, column=None, by=None, ax=None, bins=50,
figsize=None, layout=None, sharex=False, sharey=False,
rot=90, grid=True, **kwargs):
def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None,
layout=None, sharex=False, sharey=False, rot=90, grid=True,
**kwargs):
"""
Grouped histogram

Expand Down Expand Up @@ -1839,10 +1839,9 @@ def plot_group(group, ax):
return fig


def hist_frame(
data, column=None, by=None, grid=True, xlabelsize=None, xrot=None,
ylabelsize=None, yrot=None, ax=None,
sharex=False, sharey=False, **kwds):
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, ax=None, sharex=False,
sharey=False, figsize=None, **kwds):
"""
Draw Histogram the DataFrame's series using matplotlib / pylab.

Expand All @@ -1866,17 +1865,20 @@ def hist_frame(
ax : matplotlib axes object, default None
sharex : bool, if True, the X axis will be shared amongst all subplots.
sharey : bool, if True, the Y axis will be shared amongst all subplots.
figsize : tuple
The size of the figure to create in inches by default
kwds : other plotting keyword arguments
To be passed to hist function
"""
if column is not None:
if not isinstance(column, (list, np.ndarray)):
column = [column]
data = data.ix[:, column]
data = data[column]

if by is not None:

axes = grouped_hist(data, by=by, ax=ax, grid=grid, **kwds)
axes = grouped_hist(data, by=by, ax=ax, grid=grid, figsize=figsize,
**kwds)

for ax in axes.ravel():
if xlabelsize is not None:
Expand All @@ -1898,11 +1900,11 @@ def hist_frame(
rows += 1
else:
cols += 1
_, axes = _subplots(nrows=rows, ncols=cols, ax=ax, squeeze=False,
sharex=sharex, sharey=sharey)
fig, axes = _subplots(nrows=rows, ncols=cols, ax=ax, squeeze=False,
sharex=sharex, sharey=sharey, figsize=figsize)

for i, col in enumerate(com._try_sort(data.columns)):
ax = axes[i / cols][i % cols]
ax = axes[i / cols, i % cols]
ax.xaxis.set_visible(True)
ax.yaxis.set_visible(True)
ax.hist(data[col].dropna().values, **kwds)
Expand All @@ -1922,13 +1924,13 @@ def hist_frame(
ax = axes[j / cols, j % cols]
ax.set_visible(False)

ax.get_figure().subplots_adjust(wspace=0.3, hspace=0.3)
fig.subplots_adjust(wspace=0.3, hspace=0.3)

return axes


def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, **kwds):
xrot=None, ylabelsize=None, yrot=None, figsize=None, **kwds):
"""
Draw histogram of the input series using matplotlib

Expand All @@ -1948,6 +1950,8 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
If specified changes the y-axis label size
yrot : float, default None
rotation of y axis labels
figsize : tuple, default None
figure size in inches by default
kwds : keywords
To be passed to the actual plotting function

Expand All @@ -1958,16 +1962,22 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
"""
import matplotlib.pyplot as plt

fig = kwds.setdefault('figure', plt.figure(figsize=figsize))

if by is None:
if ax is None:
ax = plt.gca()
ax = fig.add_subplot(111)
else:
if ax.get_figure() != fig:
raise AssertionError('passed axis not bound to passed figure')
values = self.dropna().values

ax.hist(values, **kwds)
ax.grid(grid)
axes = np.array([ax])
else:
axes = grouped_hist(self, by=by, ax=ax, grid=grid, **kwds)
axes = grouped_hist(self, by=by, ax=ax, grid=grid, figsize=figsize,
**kwds)

for ax in axes.ravel():
if xlabelsize is not None:
Expand Down