Skip to content

Commit b1c16af

Browse files
committed
Added support for four plots; Renamed from AxesPlot to AxisPlot.
1 parent c40b69c commit b1c16af

File tree

6 files changed

+426
-78
lines changed

6 files changed

+426
-78
lines changed

README.rst

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
AxesPlot
1+
AxisPlot
22
========
33

4-
AxesPlot extends the functionality of Matplotlib's `imshow()` function by
5-
appending two plots to the image, at the top or bottom, and on the left or
4+
AxisPlot extends the functionality of Matplotlib's `imshow()` function by
5+
appending up to four plots to the image, at the top, bottom, left, and/or
66
right. The additional plots contain the output of operations that are performed
7-
along the two axes. For example, the screenshot below shows an AxesPlot with
8-
the mean of the image computed along the vertical axis, shown at the top, and
9-
the sum of the image computed along the horizontal axis, shown on the right.
7+
along the two axes. Plots at the top and bottom contain the output of
8+
operations performed along the vertical axis, while plots on the left and right
9+
contain the output of operations that are performed along the horizontal axis.
10+
For example, the screenshot below shows an AxisPlot with the mean of the image
11+
computed along the vertical axis, shown at the top, and the sum of the image
12+
computed along the horizontal axis, shown on the right.
1013

1114
.. class:: no-web
1215

13-
.. image:: https://github.com/jayanthc/axesplot/blob/master/examples/example.png
14-
:alt: AxesPlot screenshot
16+
.. image:: https://github.com/jayanthc/axisplot/blob/master/examples/example.png
17+
:alt: AxisPlot screenshot
1518
:height: 1088px
1619
:width: 1280px
1720
:scale: 60%
@@ -24,16 +27,19 @@ Usage
2427
.. code:: python
2528
2629
import numpy as np
27-
import axesplot
30+
import matplotlib.pyplot as plt
31+
import axisplot as ap
2832
2933
# generate some data
3034
dim_x = 512
3135
dim_y = 256
3236
x = np.linspace(0, 2 * np.pi, dim_x)
3337
X = np.random.normal(size=(dim_y, dim_x)) + np.sin(x)
34-
# create axesplot with mean along the y-axis at the top, and sum along the
38+
# create axisplot with mean along the y-axis at the top, and sum along the
3539
# x-axis on the right
36-
axesplot.AxesPlot(X, np.mean, np.sum, cmap='plasma')
40+
axisplot = ap.AxisPlot(optop=np.mean, opright=np.sum, cmap='plasma')
41+
axisplot.plot(X)
42+
plt.show()
3743
3844
Installation
3945
------------
@@ -42,21 +48,21 @@ Development mode:
4248

4349
::
4450

45-
cd <axesplot-directory>
51+
cd <axisplot-directory>
4652
pip install -e .
4753

48-
Unit Testinig
54+
Uniit Testinig
4955
------------
5056

5157
::
5258

53-
cd <axesplot-directory>
59+
cd <axisplot-directory>
5460
python -m unittest
5561

5662
License
5763
-------
5864

59-
AxesPlot is distributed under WTFPLv2.
65+
AxisPlot is distributed under WTFPLv2.
6066

6167

6268
----

axesplot/__init__.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

axisplot.ipynb

Lines changed: 269 additions & 0 deletions
Large diffs are not rendered by default.

axisplot/__init__.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import matplotlib as mp
4+
import mpl_toolkits.axes_grid1 as mag1
5+
6+
7+
class AxisPlot:
8+
def __init__(self, optop=None, opbottom=None, opleft=None, opright=None,
9+
figsize=None, padtop=0.0, padbottom=0.0, padleft=0.0,
10+
padright=0.0, **imshowkwargs):
11+
self.optop = optop
12+
self.opbottom = opbottom
13+
self.opleft = opleft
14+
self.opright = opright
15+
self.figsize = figsize
16+
self.padtop = padtop
17+
self.padbottom = padbottom
18+
self.padleft = padleft
19+
self.padright = padright
20+
self.imshowkwargs = imshowkwargs
21+
22+
if figsize is None:
23+
figsize = mp.rcParams['figure.figsize']
24+
# figsize is (width, height)
25+
self.aspect_ratio = figsize[0] / figsize[1]
26+
27+
default_height_frac = 0.2
28+
# heights is (y, x)
29+
self.heights = [figsize[0] * default_height_frac,
30+
figsize[1] * self.aspect_ratio * default_height_frac]
31+
32+
return
33+
34+
def plot(self, X):
35+
fig, ax = plt.subplots(figsize=self.figsize)
36+
ax.imshow(X, **self.imshowkwargs)
37+
divider = mag1.make_axes_locatable(ax)
38+
39+
x = np.linspace(0, X.shape[1], X.shape[1])
40+
y = np.linspace(0, X.shape[0], X.shape[0])
41+
42+
plot_axes = [ax]
43+
44+
if self.optop:
45+
ax_top = divider.append_axes('top', self.heights[1],
46+
pad=self.padtop, sharex=ax)
47+
ax_top.xaxis.set_tick_params(labelbottom=False)
48+
ax_top.plot(x, self.optop(X, axis=0))
49+
ax_top.set_xlim(x.min(), x.max())
50+
plot_axes.append(ax_top)
51+
if self.opbottom:
52+
ax_bottom = divider.append_axes('bottom', self.heights[1],
53+
pad=self.padbottom, sharex=ax)
54+
# turn bottom labels off for the image
55+
ax.xaxis.set_tick_params(labelbottom=False)
56+
ax_bottom.xaxis.set_tick_params(labelbottom=True)
57+
ax_bottom.plot(x, self.opbottom(X, axis=0))
58+
ax_bottom.set_xlim(x.min(), x.max())
59+
plot_axes.append(ax_bottom)
60+
if self.opleft:
61+
ax_left = divider.append_axes('left', self.heights[0],
62+
pad=self.padleft, sharey=ax)
63+
# turn left labels off for the image
64+
ax.yaxis.set_tick_params(labelleft=False)
65+
ax_left.yaxis.set_tick_params(labelleft=True)
66+
ax_left.plot(self.opleft(X, axis=1), y)
67+
ax_left.set_ylim(y.min(), y.max())
68+
plot_axes.append(ax_left)
69+
if self.opright:
70+
ax_right = divider.append_axes('right', self.heights[0],
71+
pad=self.padright, sharey=ax)
72+
ax_right.yaxis.set_tick_params(labelleft=False)
73+
ax_right.plot(self.opright(X, axis=1), y)
74+
ax_right.set_ylim(y.min(), y.max())
75+
plot_axes.append(ax_right)
76+
77+
return plot_axes

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from setuptools import setup
22

3-
setup(name='axesplot',
3+
setup(name='axisplot',
44
version='1.0',
55
description='Matplotlib imshow() Functionality Extension',
6-
url='http://github.com/jayanthc/axesplot',
6+
url='http://github.com/jayanthc/axisplot',
77
author='Jayanth Chennamangalam',
88
license='WTFPLv2',
9-
packages=['axesplot'],
9+
packages=['axisplot'],
1010
install_requires=[
1111
'numpy',
1212
'matplotlib'

tests/__init__.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,68 @@
11
import unittest
22
import numpy as np
33
import matplotlib.pyplot as plt
4-
import axesplot
4+
import axisplot as ap
55

66

7-
class TestAxesPlot(unittest.TestCase):
8-
def test(self):
7+
class TestAxisPlot(unittest.TestCase):
8+
def test_default(self):
99
# generate some data
1010
dim_x = 512
1111
dim_y = 256
1212
x = np.linspace(0, 2 * np.pi, dim_x)
1313
X = np.random.normal(size=(dim_y, dim_x)) + np.sin(x)
14-
# create axesplot with mean along the y-axis at the top, and sum along
14+
# create axisplot with mean along the y-axis at the top, and sum along
1515
# the x-axis on the right
16-
axesplot.AxesPlot(X, np.mean, np.sum, cmap='plasma')
16+
axisplot = ap.AxisPlot(optop=np.mean, opright=np.sum, cmap='plasma')
17+
ax, ax_top, ax_right = axisplot.plot(X)
18+
ax.set_xlabel('x label')
19+
ax.set_ylabel('y label')
20+
ax_top.set_title('test_default')
21+
plt.show()
22+
23+
def test_all_plots(self):
24+
# generate some data
25+
dim_x = 512
26+
dim_y = 256
27+
x = np.linspace(0, 2 * np.pi, dim_x)
28+
X = np.random.normal(size=(dim_y, dim_x)) + np.sin(x)
29+
# create axisplot with mean along the y-axis at the top, sum along the
30+
# y-axis at the bottom, mean along the x-axis on the left, and sum
31+
# along the x-axis on the right
32+
axisplot = ap.AxisPlot(optop=np.mean, opbottom=np.sum, opleft=np.mean,
33+
opright=np.sum, cmap='viridis')
34+
_, ax_top, _, _, _ = axisplot.plot(X)
35+
ax_top.set_title('test_all_plots')
36+
plt.show()
37+
38+
def test_multiple_plots(self):
39+
# generate some data
40+
dim_x = 512
41+
dim_y = 256
42+
x = np.linspace(0, 2 * np.pi, dim_x)
43+
X = np.random.normal(size=(dim_y, dim_x)) + np.sin(x)
44+
# create axisplot
45+
axisplot = ap.AxisPlot(optop=np.mean, opright=np.sum)
46+
_, ax_top, _ = axisplot.plot(X)
47+
ax_top.set_title('test_multiple_plots 0')
48+
# create another axisplot with the same settings
49+
X = np.random.normal(size=(dim_y, dim_x)) + np.cos(x)
50+
_, ax_top, _ = axisplot.plot(X)
51+
ax_top.set_title('test_multiple_plots 1')
52+
plt.show()
53+
54+
def test_axislabels(self):
55+
# generate some data
56+
dim_x = 512
57+
dim_y = 256
58+
x = np.linspace(0, 2 * np.pi, dim_x)
59+
X = np.random.normal(size=(dim_y, dim_x)) + np.sin(x)
60+
# create axisplot and update some tick labels
61+
axisplot = ap.AxisPlot(optop=np.mean, opright=np.sum)
62+
ax, ax_top, ax_right = axisplot.plot(X)
63+
ticks = ax.get_xticks()
64+
labels = ax.get_xticklabels()
65+
ax.set_xticklabels(map(lambda x: '{:.1f}'.format(x), ticks / 100))
66+
ax_right.set_xticklabels([])
67+
ax_top.set_title('test_axislabels')
1768
plt.show()

0 commit comments

Comments
 (0)