Skip to content

Commit 455da43

Browse files
committed
examples and better spines
1 parent 3e2025d commit 455da43

File tree

3 files changed

+193
-33
lines changed

3 files changed

+193
-33
lines changed

src/Graphics/Matplotlib.hs

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,19 @@
5252
-- Right now there's no easy way to bind to an option other than the last one
5353
-- unless you want to pass options in as parameters.
5454
--
55-
-- TODO The generated Python code should follow some invariants. It must maintain the
55+
-- The generated Python code should follow some invariants. It must maintain the
5656
-- current figure in "fig", all available axes in "axes", and the current axis
57-
-- in "ax".
57+
-- in "ax". Plotting commands should use the current axis, never the plot
58+
-- itself; the two APIs are almost identical. When creating low-level bindings
59+
-- one must remember to call "plot.sci" to set the current image when plotting a
60+
-- graph. The current spine of the axes that's being manipulated is in "spine".
5861
-----------------------------------------------------------------------------
5962

6063
module Graphics.Matplotlib
6164
( module Graphics.Matplotlib
6265
-- * Creating custom plots and applying options
63-
, Matplotlib(), Option(),(@@), (%), o1, o2, (##), (#), mp, def, readData,
64-
str, raw, lit, updateAxes, updateFigure)
66+
, Matplotlib(), Option(),(@@), (%), o1, o2, (##), (#), mp, def, readData
67+
, str, raw, lit, updateAxes, updateFigure, mapLinear)
6568
where
6669
import Data.List
6770
import Data.Aeson
@@ -101,20 +104,21 @@ histogram values bins = readData [values] % dataHistogram 0 bins
101104

102105
-- | Plot a 2D histogram for the given values with 'bins'
103106
histogram2D x y = readData [x,y] %
104-
mp # "plot.hist2d(data[0], data[1]" ## ")"
107+
mp # "plot.sci(ax.hist2d(data[0], data[1]" ## ")[-1])"
105108

106109
-- | Plot the given values as a scatter plot
107110
scatter :: (ToJSON t1, ToJSON t) => t1 -> t -> Matplotlib
108111
scatter x y = readData (x, y)
109-
% mp # "ax.scatter(data[0], data[1]" ## ")"
112+
% mp # "plot.sci(ax.scatter(data[0], data[1]" ## "))"
110113

111114
-- | Plot a line
112115
line :: (ToJSON t1, ToJSON t) => t1 -> t -> Matplotlib
113116
line x y = plot x y `def` [o1 "-"]
114117

115118
-- | Like 'plot' but takes an error bar value per point
116-
errorbar xs ys errs = readData (xs, ys, errs)
117-
% mp # "ax.errorbar(data[0], data[1], yerr=data[2]" ## ")"
119+
-- errorbar :: (ToJSON x, ToJSON y, ToJSON xs, ToJSON ys) => x -> y -> Maybe xs -> Maybe ys -> Matplotlib
120+
errorbar xs ys xerrs yerrs = readData (xs, ys, xerrs, yerrs)
121+
% mp # "ax.errorbar(data[0], data[1], xerr=data[2], yerr=data[3]" ## ")"
118122

119123
-- | Plot a line given a function that will be executed for each element of
120124
-- given list. The list provides the x values, the function the y values.
@@ -165,12 +169,12 @@ line1 y = line [0..length y] y
165169
-- | Plot a matrix
166170
matShow :: ToJSON a => a -> Matplotlib
167171
matShow d = readData d
168-
% (mp # "plot.matshow(data" ## ")")
172+
% (mp # "plot.sci(ax.matshow(data" ## "))")
169173

170174
-- | Plot a matrix
171175
pcolor :: ToJSON a => a -> Matplotlib
172176
pcolor d = readData d
173-
% (mp # "plot.pcolor(np.array(data)" ## ")")
177+
% (mp # "plot.sci(ax.pcolor(np.array(data)" ## "))")
174178

175179
-- | Plot a KDE of the given functions; a good bandwith will be chosen automatically
176180
density :: [Double] -> Maybe (Double, Double) -> Matplotlib
@@ -197,7 +201,7 @@ setUnicode b = mp # "matplotlib.rcParams['text.latex.unicode'] = " # b
197201

198202
-- | Plot the 'a' and 'b' entries of the data object
199203
dataPlot :: (MplotValue val, MplotValue val1) => val1 -> val -> Matplotlib
200-
dataPlot a b = mp # "p = plot.plot(data[" # a # "], data[" # b # "]" ## ")"
204+
dataPlot a b = mp # "p = ax.plot(data[" # a # "], data[" # b # "]" ## ")"
201205

202206
-- | Plot the Haskell objects 'x' and 'y' as a line
203207
plot :: (ToJSON t, ToJSON t1) => t1 -> t -> Matplotlib
@@ -216,7 +220,7 @@ dateLine x y xunit (yearStart, monthStart, dayStart) =
216220

217221
-- | Create a histogram for the 'a' entry of the data array
218222
dataHistogram :: (MplotValue val1, MplotValue val) => val1 -> val -> Matplotlib
219-
dataHistogram a bins = mp # "plot.hist(data[" # a # "]," # bins ## ")"
223+
dataHistogram a bins = mp # "ax.hist(data[" # a # "]," # bins ## ")"
220224

221225
-- | Create a scatter plot accessing the given fields of the data array
222226
dataScatter :: (MplotValue val1, MplotValue val) => val1 -> val -> Matplotlib
@@ -309,8 +313,14 @@ acorr x = readData x % mp # "ax.acorr(data" ## ")"
309313
-- | Plot text at a specified location
310314
text x y s = mp # "ax.text(" # x # "," # y # "," # raw s ## ")"
311315

316+
figText x y s = mp # "plot.figtext(" # x # "," # y # "," # raw s ## ")"
317+
312318
-- * Layout, axes, and legends
313319

320+
-- | Square up the aspect ratio of a plot.
321+
setAspect :: Matplotlib
322+
setAspect = mp # "ax.set_aspect(" ## ")"
323+
314324
-- | Square up the aspect ratio of a plot.
315325
squareAxes :: Matplotlib
316326
squareAxes = mp # "ax.set_aspect('equal')"
@@ -353,14 +363,6 @@ legend = mp # "ax.legend(" ## ")"
353363
-- TODO This refers to the plot and not an axis. Might cause trouble with subplots
354364
colorbar = mp # "plot.colorbar(" ## ")"
355365

356-
-- | Set the spacing of ticks on the x axis
357-
axisXTickSpacing :: (MplotValue val1, MplotValue val) => val1 -> val -> Matplotlib
358-
axisXTickSpacing nr width = mp # "ax.set_xticks(np.arange(" # nr # ")+" # width ## ")"
359-
360-
-- | Set the labels on the x axis
361-
axisXTickLabels :: MplotValue val => val -> Matplotlib
362-
axisXTickLabels labels = mp # "ax.set_xticklabels( (" # labels # ") " ## " )"
363-
364366
-- | Add a title
365367
title :: String -> Matplotlib
366368
title s = mp # "ax.set_title(" # raw s ## ")"
@@ -397,22 +399,70 @@ zLabel label = mp # "ax.set_zlabel(" # raw label ## ")"
397399

398400
setSizeInches w h = mp # "fig.set_size_inches(" # w # "," # h # ", forward=True)"
399401

402+
tightLayout = mp # "fig.tight_layout()"
403+
404+
xkcd = mp # "plot.xkcd()"
405+
406+
-- * Ticks
407+
408+
xticks l = mp # "ax.set_xticks(" # l # ")"
409+
yticks l = mp # "ax.set_yticks(" # l # ")"
410+
zticks l = mp # "ax.set_zticks(" # l # ")"
411+
412+
xtickLabels l = mp # "ax.set_xticklabels(" # l # ")"
413+
ytickLabels l = mp # "ax.set_yticklabels(" # l # ")"
414+
ztickLabels l = mp # "ax.set_zticklabels(" # l # ")"
415+
416+
-- | Set the spacing of ticks on the x axis
417+
axisXTickSpacing :: (MplotValue val1, MplotValue val) => val1 -> val -> Matplotlib
418+
axisXTickSpacing nr width = mp # "ax.set_xticks(np.arange(" # nr # ")+" # width ## ")"
419+
420+
-- | Set the labels on the x axis
421+
axisXTickLabels :: MplotValue val => val -> Matplotlib
422+
axisXTickLabels labels = mp # "ax.set_xticklabels( (" # labels # ") " ## " )"
423+
424+
-- | Set the spacing of ticks on the y axis
425+
axisYTickSpacing :: (MplotValue val1, MplotValue val) => val1 -> val -> Matplotlib
426+
axisYTickSpacing nr width = mp # "ax.set_yticks(np.arange(" # nr # ")+" # width ## ")"
427+
428+
-- | Set the labels on the y axis
429+
axisYTickLabels :: MplotValue val => val -> Matplotlib
430+
axisYTickLabels labels = mp # "ax.set_yticklabels( (" # labels # ") " ## " )"
431+
432+
axisXTicksPosition p = mp # "ax.xaxis.set_ticks_position('" # p # "')"
433+
axisYTicksPosition p = mp # "ax.yaxis.set_ticks_position('" # p # "')"
434+
435+
-- * Spines
436+
437+
spine s = mp # "spine = ax.spines['" # s # "']"
438+
439+
spineSetBounds l h = mp # "spine.set_bounds(" # l # "," # h # ")"
440+
441+
spineSetVisible b = mp # "spine.set_visible(" # b # ")"
442+
443+
spineSetPosition s n = mp # "spine.set_position((" # s # "," # n # "))"
444+
400445
-- * Subplots
401446

447+
setAx = mp # "plot.sca(ax) "
448+
402449
-- | Create a subplot with the coordinates (r,c,f)
403-
addSubplot r c f = mp # "ax = plot.gcf().add_subplot(" # r # c # f ## ")" % updateAxes
450+
addSubplot r c f = mp # "ax = plot.gcf().add_subplot(" # r # c # f ## ")" % updateAxes % setAx
404451

405452
-- | Access a subplot with the coordinates (r,c,f)
406-
getSubplot r c f = mp # "ax = plot.subplot(" # r # "," # c # "," # f ## ")" % updateAxes
453+
getSubplot r c f = mp # "ax = plot.subplot(" # r # "," # c # "," # f ## ")" % updateAxes % setAx
407454

408455
-- | Creates subplots and stores them in an internal variable
409456
subplots = mp # "fig, axes = plot.subplots(" ## ")"
457+
% mp # "axes = np.asarray(axes)"
458+
% mp # "axes = axes.flatten()"
459+
% updateAxes % setAx
410460

411461
-- | Access a subplot
412-
setSubplot s = mp # "ax = axes[" # s # "]"
462+
setSubplot s = mp # "ax = axes[" # s # "]" % setAx
413463

414464
-- | Add axes to a figure
415-
axes = mp # "ax = plot.axes(" ## ")" % updateAxes
465+
axes = mp # "ax = plot.axes(" ## ")" % updateAxes % setAx
416466

417467
-- | Creates a new figure with the given id. If the Id is already in use it
418468
-- switches to that figure.

src/Graphics/Matplotlib/Internal.hs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,29 @@ instance MplotValue [Int] where
164164
toPython s = "[" ++ f s ++ "]"
165165
where f [] = ""
166166
f (x:xs) = toPython x ++ "," ++ f xs
167+
instance MplotValue [R] where
168+
toPython s = "[" ++ f s ++ "]"
169+
where f [] = ""
170+
f (x:xs) = toPython x ++ "," ++ f xs
171+
instance MplotValue [S] where
172+
toPython s = "[" ++ f s ++ "]"
173+
where f [] = ""
174+
f (x:xs) = toPython x ++ "," ++ f xs
175+
instance MplotValue [L] where
176+
toPython s = "[" ++ f s ++ "]"
177+
where f [] = ""
178+
f (x:xs) = toPython x ++ "," ++ f xs
167179
instance MplotValue Bool where
168180
toPython s = show s
169181
instance (MplotValue x) => MplotValue (x, x) where
170-
toPython (n, v) = toPython n ++ " = " ++ toPython v
182+
toPython (k, v) = "(" ++ toPython k ++ ", " ++ toPython v ++ ")"
171183
instance (MplotValue (x, y)) => MplotValue [(x, y)] where
172-
toPython [] = ""
173-
toPython (x:xs) = toPython x ++ ", " ++ toPython xs
184+
toPython s = "[" ++ f s ++ "]"
185+
where f [] = ""
186+
f (x:xs) = toPython x ++ "," ++ f xs
187+
instance MplotValue x => MplotValue (Maybe x) where
188+
toPython Nothing = "None"
189+
toPython (Just x) = toPython x
174190

175191
default (Integer, Int, Double)
176192

test/Spec.hs

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
{-# language ExtendedDefaultRules, ScopedTypeVariables, QuasiQuotes #-}
1+
{-# language ExtendedDefaultRules, ScopedTypeVariables, QuasiQuotes, ParallelListComp #-}
22

33
import Test.Tasty
44
import Test.Tasty.Runners
@@ -133,7 +133,7 @@ testPlotGolden name fn =
133133
case (stderr, reads stderr) of
134134
("inf", _) -> return Nothing
135135
(_, [(x :: Double, _)]) ->
136-
if x < 30 then
136+
if x < 35 then
137137
return $ Just $ "Images very different; PSNR too low " ++ show x else
138138
return Nothing)))
139139
(BS.writeFile ref))
@@ -171,6 +171,11 @@ basicTests f = testGroup "Basic tests"
171171
, f "eventplot" meventplot
172172
, f "errorbar" merrorbar
173173
, f "scatterhist" mscatterHist
174+
, f "histMulti" mhistMulti
175+
, f "spines" mspines
176+
, f "hists" mhists
177+
, f "hinton" mhinton
178+
, f "integral" mintegral
174179
]
175180

176181
fragileTests f = testGroup "Fragile tests"
@@ -252,6 +257,7 @@ mlegend = plotMapLinear (\x -> x ** 2) 0 1 100 @@ [o2 "label" "x^2"]
252257

253258
-- | http://matplotlib.org/examples/pylab_examples/hist2d_log_demo.html
254259
mhist2DLog = histogram2D x y @@ [o2 "bins" 40, o2 "norm" $ lit "mcolors.LogNorm()"]
260+
% setAx
255261
% colorbar
256262
where (x:y:_) = chunksOf 10000 normals
257263

@@ -262,10 +268,10 @@ meventplot = plot xs ys
262268
where xs = sort $ take 10 uniforms
263269
ys = map (\x -> x ** 2) xs
264270

265-
merrorbar = errorbar xs ys errs @@ [o2 "errorevery" 2]
271+
merrorbar = errorbar xs ys (Nothing :: Maybe [Double]) (Just errs) @@ [o2 "errorevery" 2]
266272
where xs = [0.1,0.2..4]
267273
ys = map (\x -> exp $ -x) xs
268-
errs = map (\x -> 0.1 + 0.1 * sqrt x) xs
274+
errs = [map (\x -> 0.1 + 0.1 * sqrt x) xs, map (\x -> 0.1 + 0.1 * sqrt x) ys]
269275

270276
mboxplot = subplots @@ [o2 "ncols" 2, o2 "sharey" True]
271277
% setSubplot "0"
@@ -280,8 +286,7 @@ mviolinplot = subplots @@ [o2 "ncols" 2, o2 "sharey" True]
280286
% violinplot (take 3 $ chunksOf 100 $ map (* 2) $ normals) @@ [o2 "showmeans" True, o2 "showmedians" True, o2 "vert" False]
281287

282288
-- | http://matplotlib.org/examples/pylab_examples/scatter_hist.html
283-
mscatterHist = readData ()
284-
% figure 0
289+
mscatterHist = figure 0
285290
% setSizeInches 8 8
286291
-- The scatter plot
287292
% axes @@ [o1 ([left, bottom', width, height] :: [Double])]
@@ -309,3 +314,92 @@ mscatterHist = readData ()
309314
xymax = maximum [maximum $ map abs x, maximum $ map abs y]
310315
lim = ((fromIntegral $ round $ xymax / binwidth) + 1) * binwidth
311316
bins = [-lim,-lim+binwidth..(lim + binwidth)]
317+
318+
mhistMulti = subplots @@ [o2 "nrows" 2, o2 "ncols" 2]
319+
% setSubplot 0
320+
% histogram x nrBins @@ [o2 "normed" 1, o2 "histtype" "bar", o2 "color" ["red", "tan", "lime"], o2 "label" ["red", "tan", "lime"]]
321+
% legend @@ [o2 "prop" $ lit "{'size': 10}"]
322+
% title "bars with legend"
323+
% setSubplot 1
324+
% histogram x nrBins @@ [o2 "normed" 1, o2 "histtype" "bar", o2 "stacked" True]
325+
% title "stacked bar"
326+
% setSubplot 2
327+
% histogram x nrBins @@ [o2 "histtype" "step", o2 "stacked" True, o2 "fill" False]
328+
% title "stacked bar"
329+
% setSubplot 3
330+
% histogram (map (\x -> take x normals) [2000, 5000, 10000]) nrBins @@ [o2 "histtype" "bar"]
331+
% title "different sample sizes"
332+
% tightLayout
333+
where nrBins = 10
334+
x = take 3 $ chunksOf 1000 $ normals
335+
336+
mspines = plot x y @@ [o1 "k--"]
337+
% plot x y' @@ [o1 "ro"]
338+
% xlim 0 (2 * pi)
339+
% xticks [0 :: Double, pi, 2*pi]
340+
% xtickLabels (map raw ["0", "$\\pi$", "2$\\pi$"])
341+
% ylim (-1.5) 1.5
342+
% yticks [-1 :: Double, 0, 1]
343+
% spine "left"
344+
% spineSetBounds (-1) 1
345+
% spine "right"
346+
% spineSetVisible False
347+
% spine "top"
348+
% spineSetVisible False
349+
% axisYTicksPosition "left"
350+
% axisXTicksPosition "bottom"
351+
where x = mapLinear (\x -> x) 0 (2 * pi) 50
352+
y = map sin x
353+
y' = zipWith (\a b -> a + 0.1*b) y normals
354+
355+
mhists = h 10 1.5
356+
% h 4 1
357+
% h 15 2
358+
% h 6 0.5
359+
where ns mu var = map (\x -> mu + x * var) $ take 1000 normals
360+
h mu var = histogram (ns mu var) 25 @@ [o2 "histtype" "stepfilled"
361+
,o2 "alpha" 0.8
362+
,o2 "normed" True]
363+
364+
mhinton = mp # "ax.patch.set_facecolor('gray')"
365+
% setAspect @@ [o1 "equal", o1 "box"]
366+
% mp # "ax.xaxis.set_major_locator(plot.NullLocator())"
367+
% mp # "ax.yaxis.set_major_locator(plot.NullLocator())"
368+
% foldl (\a (x,y,w) -> a % f x y w) mp m
369+
% mp # "ax.autoscale_view()"
370+
% mp # "ax.invert_yaxis()"
371+
where m = [ (x,y,w) | x <- [0..19], y <- [0..19] | w <- (map (\x -> x - 0.5) normals) ]
372+
maxWeight = maximum $ map (\(_,_,v) -> abs v) m
373+
f x y w = mp # "ax.add_patch(plot.Rectangle("
374+
# "[" # (x - size / 2) # "," # (y - size / 2) # "]"
375+
# ", " # size # ", " # size
376+
# ", facecolor='" # color # "', edgecolor='" # color # "'))"
377+
where color = if w > 0 then "white" else "black"
378+
size = sqrt $ abs w / maxWeight
379+
380+
mintegral = subplots
381+
% plot x y @@ [o1 "r", o2 "linewidth" 2]
382+
% ylim 0 (maximum y)
383+
% mp # "ax.add_patch(plot.Polygon(" # ([(a, 0)] ++ zip ix iy ++ [(b,0)]) ## "))"
384+
@@ [o2 "facecolor" "0.9", o2 "edgecolor" "0.5"]
385+
% text (0.5 * (a + b)) 30 [r|$\int_a^b f(x)\mathrm{d}x$|]
386+
@@ [o2 "horizontalalignment" "center", o2 "fontsize" 20]
387+
% figText 0.9 0.05 "$x$"
388+
% figText 0.1 0.9 "$y$"
389+
% spine "right"
390+
% spineSetVisible False
391+
% spine "top"
392+
% spineSetVisible False
393+
% axisXTicksPosition "bottom"
394+
% xticks (a, b)
395+
% xtickLabels (raw "$a$", raw "$b$")
396+
% yticks ([] :: [Double])
397+
where func x = (x - 3) * (x - 5) * (x - 7) + 85
398+
-- integral limits
399+
a = 2 :: Double
400+
b = 9 :: Double
401+
(x :: [Double]) = mapLinear (\x -> x) 0 10 100
402+
y = map func x
403+
-- shaded region
404+
(ix :: [Double]) = mapLinear (\x -> x) a b 100
405+
iy = map func ix

0 commit comments

Comments
 (0)