Skip to content

Commit 983834c

Browse files
committed
benchmarks; first MNIST demo
1 parent 8bca760 commit 983834c

File tree

9 files changed

+189
-34
lines changed

9 files changed

+189
-34
lines changed

benchmark/benchmark.hs

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{-# LANGUAGE DataKinds #-}
2+
{-# LANGUAGE BangPatterns #-}
3+
4+
module Main where
5+
6+
import Control.Arrow hiding (loop)
7+
import Criterion.Main
8+
import Data.MyPrelude
9+
import Data.Utils
10+
import Data.Void
11+
import Numeric.Neural
12+
13+
main :: IO ()
14+
main = defaultMain
15+
[ bgroup "white"
16+
[ bench "10/200" $ whnf (w 10) 200
17+
, bench "10/2000" $ whnf (w 10) 2000
18+
, bench "10/20000" $ whnf (w 10) 20000
19+
, bench "100/200" $ whnf (w 100) 200
20+
, bench "100/2000" $ whnf (w 100) 2000
21+
, bench "100/20000" $ whnf (w 100) 20000
22+
, bench "1000/200" $ whnf (w 1000) 200
23+
, bench "1000/2000" $ whnf (w 1000) 2000
24+
, bench "1000/20000" $ whnf (w 1000) 20000
25+
]
26+
, env setupEnv $ \ ~(m, xss) -> bgroup "linear"
27+
[ l m xss 1 5
28+
, l m xss 5 5
29+
, l m xss 10 5
30+
]
31+
]
32+
33+
w :: Int -> Int -> Double
34+
w sampleCount testCount = flip evalRand (mkStdGen 123456) $ do
35+
stats <- mkStats'
36+
samples <- replicateM sampleCount $ mkSample stats
37+
let m = whiten model' samples
38+
xss <- replicateM testCount $ mkSample stats
39+
return $ sum [model m xs | xs <- xss]
40+
41+
where
42+
43+
mkStats' :: MonadRandom m => m (Vector Width (Double, Double))
44+
mkStats' = sequenceA (pure $ (,) <$> getRandomR (-100, 100) <*> getRandomR (0.1, 20))
45+
46+
mkSample :: MonadRandom m => Vector Width (Double, Double) -> m (Vector Width Double)
47+
mkSample = mapM $ uncurry boxMuller'
48+
49+
model' :: Model (Vector Width) Identity Void (Vector Width Double) Double
50+
model' = Model (arr $ Identity . sum) absurd id runIdentity
51+
52+
type Width = 10
53+
54+
l :: M -> [Vector Width' Double] -> Int -> Int -> Benchmark
55+
l m xss batchSize steps = bench (printf "%d/%d" batchSize steps) $ whnf l' steps where
56+
57+
l' :: Int -> Double
58+
l' steps' =
59+
let m' = loop steps' m
60+
xs = pure 0
61+
in modelError m' [(xs, xs)]
62+
63+
loop :: Int -> M -> M
64+
loop 0 m' = m'
65+
loop !n m' =
66+
let m'' = m' `deepseq` snd $ descent m' 0.01 [(xs, xs) | xs <- take batchSize xss]
67+
in loop (pred n) m''
68+
69+
setupEnv :: IO (M, [Vector Width' Double])
70+
setupEnv = return $ flip evalRand (mkStdGen 987654) $ do
71+
m <- modelR $ mkStdModel linearLayer (sqDiff . (fromDouble <$>)) id id
72+
xss <- replicateM 100 $ let r = getRandomR (-5, 5) in sequence $ pure r
73+
return (m, xss)
74+
75+
type M = StdModel (Vector Width') (Vector Width') (Vector Width' Double) (Vector Width' Double)
76+
77+
type Width' = 100

examples/MNIST/MNIST.hs

+41-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
{-# LANGUAGE DataKinds #-}
2+
13
module Main where
24

35
import Codec.Picture
6+
import Control.Arrow
47
import qualified Data.Array as A
58
import Data.MyPrelude
69
import Data.Utils
@@ -9,7 +12,36 @@ import Pipes.GZip (decompress)
912
import qualified Pipes.Prelude as P
1013

1114
main :: IO ()
12-
main = runSafeT (runEffect $ trainSamples >-> P.take 50 >-> consumeSamples "test")
15+
main = do
16+
xs <- runSafeT $ P.toListM (trainSamples >-> P.take 1000)
17+
printf "loaded %d train samples\n" (length xs)
18+
ys <- runSafeT $ P.toListM (testSamples >-> P.take 500)
19+
printf "loaded %d test samples\n" (length ys)
20+
flip evalRandT (mkStdGen 999999) $ do
21+
xs' <- takeR 100 $ fst <$> xs
22+
m <- modelR (whiten mnistModel xs')
23+
runEffect $
24+
simpleBatchP xs 20
25+
>-> descentP m 1 (const 0.1)
26+
>-> reportTSP 1 report
27+
>-> consumeTSP (check ys)
28+
29+
where
30+
31+
report ts = liftIO $ printf "%7d %8.6f %10.8f\n" (tsGeneration ts) (tsEta ts) (tsBatchError ts)
32+
33+
check ys ts = --return $ if tsGeneration ts == 3 then Just () else Nothing
34+
if tsGeneration ts `mod` 5 == 0
35+
then do
36+
let a = accuracy (tsModel ts) ys :: Double
37+
liftIO $ printf "\naccuracy %f\n\n" a
38+
return Nothing
39+
else return Nothing
40+
41+
correct m (img, d) = model m img == d
42+
43+
accuracy m ys = let c = length $ filter (correct m) ys
44+
in fromIntegral c / fromIntegral (length ys)
1345

1446
type Img = Image Pixel8
1547

@@ -45,10 +77,12 @@ testSamples = P.zip (images testImagesFile) (labels testLabelsFile)
4577
writeImg :: MonadIO m => FilePath -> Img -> m ()
4678
writeImg f i = liftIO $ saveTiffImage (f <.> "tiff") (ImageY8 i)
4779

48-
consumeSamples :: MonadIO m => String -> Consumer Sample m ()
49-
consumeSamples f = g (1 :: Int) where
80+
mnistModel :: Classifier (Matrix 28 28) 10 Img Digit
81+
mnistModel = mkStdClassifier c i where
82+
83+
c = f ^>> (tanhLayer :: Layer 784 10) >>> tanhLayer
84+
85+
i img = let m = mgenerate $ \(x, y) -> fromIntegral (pixelAt img x y) in force m
5086

51-
g i = do
52-
(img, l) <- await
53-
writeImg (printf "%s_%05d_%d" f i (fromEnum l)) img
54-
g (succ i)
87+
f :: Matrix 28 28 Analytic -> Vector 784 Analytic
88+
f m = generate $ \w -> m !!! (w `mod` 28, w `div` 28)

neural.cabal

+14
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,16 @@ test-suite neural-doctest
112112
, doctest
113113
, Glob
114114
ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N -fexcess-precision -optc-O3 -optc-ffast-math
115+
default-language: Haskell2010
115116

117+
benchmark neural-bench
118+
type: exitcode-stdio-1.0
119+
hs-source-dirs: benchmark
120+
main-is: benchmark.hs
121+
build-depends: base >= 4.7 && < 5
122+
, criterion
123+
, neural
124+
ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N -fexcess-precision -optc-O3 -optc-ffast-math
116125
default-language: Haskell2010
117126

118127
executable iris
@@ -150,3 +159,8 @@ source-repository head
150159
type: git
151160
location: https://github.com/brunjlar/neural.git
152161

162+
source-repository this
163+
type: git
164+
location: https://github.com/brunjlar/neural.git
165+
tag: 0.1.1.0
166+

src/Data/MyPrelude.hs

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ This module simply reexports a selection of commonly used standard types and fun
1313
-}
1414

1515
module Data.MyPrelude
16-
( NFData(..)
16+
( NFData(..), deepseq, force
1717
, (&), (^.), (.~), Lens', Getter, to, lens
1818
, when, unless, forM, forM_, void, replicateM, forever, guard
1919
, Identity(..)
2020
, MonadIO(..)
21-
, MonadRandom, getRandom, getRandomR, RandT, runRandT, evalRandT, StdGen, mkStdGen
21+
, MonadRandom, getRandom, getRandomR, Rand, RandT, runRand, evalRand, runRandT, evalRandT, StdGen, mkStdGen
2222
, MonadState(..)
2323
, lift
2424
, State, StateT, modify, runState, evalState, execState, runStateT, evalStateT, execStateT
@@ -36,12 +36,12 @@ module Data.MyPrelude
3636
, printf
3737
) where
3838

39-
import Control.DeepSeq (NFData(..))
39+
import Control.DeepSeq (NFData(..), deepseq, force)
4040
import Control.Lens ((&), (^.), (.~), Lens', Getter, to, lens)
4141
import Control.Monad (when, unless, forM, forM_, void, replicateM, forever, guard)
4242
import Control.Monad.Identity (Identity(..))
4343
import Control.Monad.IO.Class (MonadIO(..))
44-
import Control.Monad.Random (MonadRandom, getRandom, getRandomR, RandT, runRandT, evalRandT, StdGen, mkStdGen)
44+
import Control.Monad.Random (MonadRandom, getRandom, getRandomR, Rand, RandT, runRand, evalRand, runRandT, evalRandT, StdGen, mkStdGen)
4545
import Control.Monad.State.Class (MonadState(..))
4646
import Control.Monad.Trans.Class (lift)
4747
import Control.Monad.Trans.State (State, StateT, modify, runState, evalState, execState, runStateT, evalStateT, execStateT)

src/Data/Utils/Matrix.hs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
{-# LANGUAGE DeriveTraversable #-}
88
{-# LANGUAGE TypeOperators #-}
99
{-# LANGUAGE TypeFamilies #-}
10+
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
1011

1112
{-|
1213
Module : Data.Utils.Matrix
@@ -38,7 +39,7 @@ import Data.Utils.Vector
3839
-- | @'Matrix' m n a@ is the type of /matrices/ with @m@ rows, @n@ columns and entries of type @a@.
3940
--
4041
newtype Matrix (m :: Nat) (n :: Nat) a = Matrix (Vector m (Vector n a))
41-
deriving (Eq, Show, Functor, Foldable, Traversable)
42+
deriving (Eq, Show, Functor, Foldable, Traversable, NFData)
4243

4344
instance (KnownNat m, KnownNat n) => Applicative (Matrix m n) where
4445

src/Data/Utils/Vector.hs

+5-1
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,18 @@ instance (KnownNat n, Read a) => Read (Vector n a) where
7979
n' = fromIntegral (natVal (Proxy :: Proxy n))
8080
in [(Vector ys, t) | (ys, t) <- xs, length ys == n']
8181

82+
instance (NFData a) => NFData (Vector n a) where
83+
84+
rnf (Vector v) = rnf v
85+
8286
-- | The /scalar product/ of two vectors of the same length.
8387
--
8488
-- >>> :set -XDataKinds
8589
-- >>> cons 1 (cons 2 nil) <%> cons 3 (cons 4 nil) :: Int
8690
-- 11
8791
--
8892
(<%>) :: Num a => Vector n a -> Vector n a -> a
89-
xs <%> ys = sum $ zipWith (*) (toList xs) (toList ys)
93+
Vector v <%> Vector w = V.sum $ V.zipWith (*) v w
9094

9195
-- | The vector of length zero.
9296
nil :: Vector 0 a

src/Numeric/Neural/Model.hs

+38-16
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
{-# LANGUAGE Arrows #-}
99
{-# LANGUAGE GADTs #-}
1010
{-# LANGUAGE KindSignatures #-}
11+
{-# LANGUAGE FlexibleContexts #-}
1112

1213
{-|
1314
Module : Neural.Model
@@ -40,6 +41,7 @@ module Numeric.Neural.Model
4041
, mkStdModel
4142
) where
4243

44+
import Control.Applicative
4345
import Control.Arrow
4446
import Control.Category
4547
import Data.Profunctor
@@ -88,7 +90,7 @@ instance Profunctor (ParamFun t) where dimap = dimapArr
8890
-- In contrast to 'ParamFun', when components are composed, parameters are not shared.
8991
-- Each component carries its own collection of parameters instead.
9092
--
91-
data Component a b = forall t. (Traversable t, Applicative t) => Component
93+
data Component a b = forall t. (Traversable t, Applicative t, NFData (t Double)) => Component
9294
{ weights :: t Double -- ^ the specific parameter values
9395
, compute :: ParamFun t a b -- ^ the encapsulated parameterized function
9496
, initR :: forall m. MonadRandom m => m (t Double) -- ^ randomly sets the parameters
@@ -115,8 +117,16 @@ instance Applicative Empty where
115117

116118
Empty <*> Empty = Empty
117119

120+
instance NFData (Empty a) where
121+
122+
rnf Empty = ()
123+
118124
data Pair s t a = Pair (s a) (t a) deriving (Show, Read, Eq, Ord, Functor, Foldable, Traversable)
119125

126+
instance (NFData (s a), NFData (t a)) => NFData (Pair s t a) where
127+
128+
rnf (Pair xs ys) = rnf xs `seq` rnf ys `seq` ()
129+
120130
instance (Applicative s, Applicative t) => Applicative (Pair s t) where
121131

122132
pure x = Pair (pure x) (pure x)
@@ -161,6 +171,10 @@ instance Applicative (Component a) where pure = pureArr; (<*>) = apArr
161171

162172
instance Profunctor Component where dimap = dimapArr
163173

174+
instance NFData (Component a b) where
175+
176+
rnf (Component ws _ _) = rnf ws
177+
164178
-- | A @'Model' f g a b c@ wraps a @'Component' (f 'Analytic') (g 'Analytic')@
165179
-- and models functions @b -> c@ with "samples" (for model error determination)
166180
-- of type @a@.
@@ -178,6 +192,10 @@ instance Profunctor (Model f g a) where
178192

179193
dimap m n (Model c e i o) = Model c e (i . m) (n . o)
180194

195+
instance NFData (Model f g a b c) where
196+
197+
rnf (Model c _ _ _) = rnf c
198+
181199
-- | A 'Lens' for accessing the component embedded in a model.
182200
--
183201
_component :: Lens' (Model f g a b c) (Component (f Analytic) (g Analytic))
@@ -195,28 +213,29 @@ modelR (Model c e i o) = case c of
195213
ws <- r
196214
return $ Model (Component ws f r) e i o
197215

198-
errFun :: (Functor f, Foldable h, Traversable t)
216+
errFun :: (Functor f, Traversable t)
199217
=> (a -> (f Double, g Analytic -> Analytic))
200-
-> h a
218+
-> a
201219
-> ParamFun t (f Analytic) (g Analytic)
202220
-> (t Analytic -> Analytic)
203-
errFun e xs f = runPF f' xs where
204-
205-
f' = toList ^>> convolve f'' >>^ mean
221+
errFun e x f = runPF f' x where
206222

207-
f'' = proc x -> do
208-
let (x', h) = e x
223+
f' = proc z -> do
224+
let (x', h) = e z
209225
x'' = fromDouble <$> x'
210226
y <- f -< x''
211227
returnA -< h y
212228

229+
modelError' :: Model f g a b c -> a -> Double
230+
modelError' (Model c e _ _) x = case c of
231+
Component ws f _ -> let f' = errFun e x f
232+
f'' = fromJust . fromAnalytic . f' . fmap fromDouble
233+
in f'' ws
234+
213235
-- | Calculates the avarage model error for a "mini-batch" of samples.
214236
--
215237
modelError :: Foldable h => Model f g a b c -> h a -> Double
216-
modelError (Model c e _ _) xs = case c of
217-
Component ws f _ -> let f' = errFun e xs f
218-
f'' = fromJust . fromAnalytic . f' . fmap fromDouble
219-
in f'' ws
238+
modelError m xs = mean $ modelError' m <$> toList xs
220239

221240
-- | Performs one step of gradient descent/ backpropagation on the model,
222241
descent :: (Foldable h)
@@ -226,10 +245,13 @@ descent :: (Foldable h)
226245
-> (Double, Model f g a b c) -- ^ returns the average sample error and the improved model
227246
descent (Model c e i o) eta xs = case c of
228247
Component ws f r ->
229-
let f' = errFun e xs f
230-
(err, ws') = gradient (\w dw -> w - eta * dw) f' ws
231-
c' = Component ws' f r
232-
m = Model c' e i o
248+
let f' x = gradient (\_ dw -> dw) (errFun e x f) ws
249+
ys = f' <$> toList xs
250+
err = mean $ fst <$> ys
251+
grad = (* eta) . mean . getZipList <$> sequenceA (ZipList (snd <$> ys))
252+
ws' = (-) <$> ws <*> grad
253+
c' = Component ws' f r
254+
m = Model c' e i o
233255
in (err, m)
234256

235257
-- | A type abbreviation for the most common type of models, where samples are just input-output tuples.

src/Numeric/Neural/Normalization.hs

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
{-# LANGUAGE RankNTypes #-}
44
{-# LANGUAGE TypeOperators #-}
55
{-# LANGUAGE DataKinds #-}
6+
{-# LANGUAGE BangPatterns #-}
67

78
{-|
89
Module : Numeric.Neural.Normalization
@@ -155,8 +156,8 @@ white xss = ((w <$> sequenceA xss) <*>) where
155156

156157
w xs = case toList xs of
157158
[] -> id
158-
xs' -> let (_, m, v) = countMeanVar xs'
159-
s = if v == 0 then 1 else 1 / sqrt v
159+
xs' -> let (_, !m, !v) = countMeanVar xs'
160+
!s = if v == 0 then 1 else 1 / sqrt v
160161
in \x -> (x - m) * s
161162

162163
-- | Modifies a 'Model' by whitening the input before feeding it into the embedded component.
@@ -182,4 +183,4 @@ mkStdClassifier :: (Functor f, KnownNat n, Enum c)
182183
=> Component (f Analytic) (Vector n Analytic) -- ^ the embedded component
183184
-> (b -> f Double) -- ^ converts input
184185
-> Classifier f n b c
185-
mkStdClassifier c i = mkStdModel (c >>^ softmax) crossEntropyError i decode1ofN where
186+
mkStdClassifier c i = mkStdModel (c >>^ softmax) crossEntropyError i decode1ofN

0 commit comments

Comments
 (0)