Skip to content

Commit 0358b4a

Browse files
committed
Iris dataset example
1 parent 2a3cc41 commit 0358b4a

File tree

6 files changed

+293
-5
lines changed

6 files changed

+293
-5
lines changed

examples/iris/data.csv

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
5.1,3.5,1.4,0.2,Iris-setosa
2+
4.9,3.0,1.4,0.2,Iris-setosa
3+
4.7,3.2,1.3,0.2,Iris-setosa
4+
4.6,3.1,1.5,0.2,Iris-setosa
5+
5.0,3.6,1.4,0.2,Iris-setosa
6+
5.4,3.9,1.7,0.4,Iris-setosa
7+
4.6,3.4,1.4,0.3,Iris-setosa
8+
5.0,3.4,1.5,0.2,Iris-setosa
9+
4.4,2.9,1.4,0.2,Iris-setosa
10+
4.9,3.1,1.5,0.1,Iris-setosa
11+
5.4,3.7,1.5,0.2,Iris-setosa
12+
4.8,3.4,1.6,0.2,Iris-setosa
13+
4.8,3.0,1.4,0.1,Iris-setosa
14+
4.3,3.0,1.1,0.1,Iris-setosa
15+
5.8,4.0,1.2,0.2,Iris-setosa
16+
5.7,4.4,1.5,0.4,Iris-setosa
17+
5.4,3.9,1.3,0.4,Iris-setosa
18+
5.1,3.5,1.4,0.3,Iris-setosa
19+
5.7,3.8,1.7,0.3,Iris-setosa
20+
5.1,3.8,1.5,0.3,Iris-setosa
21+
5.4,3.4,1.7,0.2,Iris-setosa
22+
5.1,3.7,1.5,0.4,Iris-setosa
23+
4.6,3.6,1.0,0.2,Iris-setosa
24+
5.1,3.3,1.7,0.5,Iris-setosa
25+
4.8,3.4,1.9,0.2,Iris-setosa
26+
5.0,3.0,1.6,0.2,Iris-setosa
27+
5.0,3.4,1.6,0.4,Iris-setosa
28+
5.2,3.5,1.5,0.2,Iris-setosa
29+
5.2,3.4,1.4,0.2,Iris-setosa
30+
4.7,3.2,1.6,0.2,Iris-setosa
31+
4.8,3.1,1.6,0.2,Iris-setosa
32+
5.4,3.4,1.5,0.4,Iris-setosa
33+
5.2,4.1,1.5,0.1,Iris-setosa
34+
5.5,4.2,1.4,0.2,Iris-setosa
35+
4.9,3.1,1.5,0.2,Iris-setosa
36+
5.0,3.2,1.2,0.2,Iris-setosa
37+
5.5,3.5,1.3,0.2,Iris-setosa
38+
4.9,3.6,1.4,0.1,Iris-setosa
39+
4.4,3.0,1.3,0.2,Iris-setosa
40+
5.1,3.4,1.5,0.2,Iris-setosa
41+
5.0,3.5,1.3,0.3,Iris-setosa
42+
4.5,2.3,1.3,0.3,Iris-setosa
43+
4.4,3.2,1.3,0.2,Iris-setosa
44+
5.0,3.5,1.6,0.6,Iris-setosa
45+
5.1,3.8,1.9,0.4,Iris-setosa
46+
4.8,3.0,1.4,0.3,Iris-setosa
47+
5.1,3.8,1.6,0.2,Iris-setosa
48+
4.6,3.2,1.4,0.2,Iris-setosa
49+
5.3,3.7,1.5,0.2,Iris-setosa
50+
5.0,3.3,1.4,0.2,Iris-setosa
51+
7.0,3.2,4.7,1.4,Iris-versicolor
52+
6.4,3.2,4.5,1.5,Iris-versicolor
53+
6.9,3.1,4.9,1.5,Iris-versicolor
54+
5.5,2.3,4.0,1.3,Iris-versicolor
55+
6.5,2.8,4.6,1.5,Iris-versicolor
56+
5.7,2.8,4.5,1.3,Iris-versicolor
57+
6.3,3.3,4.7,1.6,Iris-versicolor
58+
4.9,2.4,3.3,1.0,Iris-versicolor
59+
6.6,2.9,4.6,1.3,Iris-versicolor
60+
5.2,2.7,3.9,1.4,Iris-versicolor
61+
5.0,2.0,3.5,1.0,Iris-versicolor
62+
5.9,3.0,4.2,1.5,Iris-versicolor
63+
6.0,2.2,4.0,1.0,Iris-versicolor
64+
6.1,2.9,4.7,1.4,Iris-versicolor
65+
5.6,2.9,3.6,1.3,Iris-versicolor
66+
6.7,3.1,4.4,1.4,Iris-versicolor
67+
5.6,3.0,4.5,1.5,Iris-versicolor
68+
5.8,2.7,4.1,1.0,Iris-versicolor
69+
6.2,2.2,4.5,1.5,Iris-versicolor
70+
5.6,2.5,3.9,1.1,Iris-versicolor
71+
5.9,3.2,4.8,1.8,Iris-versicolor
72+
6.1,2.8,4.0,1.3,Iris-versicolor
73+
6.3,2.5,4.9,1.5,Iris-versicolor
74+
6.1,2.8,4.7,1.2,Iris-versicolor
75+
6.4,2.9,4.3,1.3,Iris-versicolor
76+
6.6,3.0,4.4,1.4,Iris-versicolor
77+
6.8,2.8,4.8,1.4,Iris-versicolor
78+
6.7,3.0,5.0,1.7,Iris-versicolor
79+
6.0,2.9,4.5,1.5,Iris-versicolor
80+
5.7,2.6,3.5,1.0,Iris-versicolor
81+
5.5,2.4,3.8,1.1,Iris-versicolor
82+
5.5,2.4,3.7,1.0,Iris-versicolor
83+
5.8,2.7,3.9,1.2,Iris-versicolor
84+
6.0,2.7,5.1,1.6,Iris-versicolor
85+
5.4,3.0,4.5,1.5,Iris-versicolor
86+
6.0,3.4,4.5,1.6,Iris-versicolor
87+
6.7,3.1,4.7,1.5,Iris-versicolor
88+
6.3,2.3,4.4,1.3,Iris-versicolor
89+
5.6,3.0,4.1,1.3,Iris-versicolor
90+
5.5,2.5,4.0,1.3,Iris-versicolor
91+
5.5,2.6,4.4,1.2,Iris-versicolor
92+
6.1,3.0,4.6,1.4,Iris-versicolor
93+
5.8,2.6,4.0,1.2,Iris-versicolor
94+
5.0,2.3,3.3,1.0,Iris-versicolor
95+
5.6,2.7,4.2,1.3,Iris-versicolor
96+
5.7,3.0,4.2,1.2,Iris-versicolor
97+
5.7,2.9,4.2,1.3,Iris-versicolor
98+
6.2,2.9,4.3,1.3,Iris-versicolor
99+
5.1,2.5,3.0,1.1,Iris-versicolor
100+
5.7,2.8,4.1,1.3,Iris-versicolor
101+
6.3,3.3,6.0,2.5,Iris-virginica
102+
5.8,2.7,5.1,1.9,Iris-virginica
103+
7.1,3.0,5.9,2.1,Iris-virginica
104+
6.3,2.9,5.6,1.8,Iris-virginica
105+
6.5,3.0,5.8,2.2,Iris-virginica
106+
7.6,3.0,6.6,2.1,Iris-virginica
107+
4.9,2.5,4.5,1.7,Iris-virginica
108+
7.3,2.9,6.3,1.8,Iris-virginica
109+
6.7,2.5,5.8,1.8,Iris-virginica
110+
7.2,3.6,6.1,2.5,Iris-virginica
111+
6.5,3.2,5.1,2.0,Iris-virginica
112+
6.4,2.7,5.3,1.9,Iris-virginica
113+
6.8,3.0,5.5,2.1,Iris-virginica
114+
5.7,2.5,5.0,2.0,Iris-virginica
115+
5.8,2.8,5.1,2.4,Iris-virginica
116+
6.4,3.2,5.3,2.3,Iris-virginica
117+
6.5,3.0,5.5,1.8,Iris-virginica
118+
7.7,3.8,6.7,2.2,Iris-virginica
119+
7.7,2.6,6.9,2.3,Iris-virginica
120+
6.0,2.2,5.0,1.5,Iris-virginica
121+
6.9,3.2,5.7,2.3,Iris-virginica
122+
5.6,2.8,4.9,2.0,Iris-virginica
123+
7.7,2.8,6.7,2.0,Iris-virginica
124+
6.3,2.7,4.9,1.8,Iris-virginica
125+
6.7,3.3,5.7,2.1,Iris-virginica
126+
7.2,3.2,6.0,1.8,Iris-virginica
127+
6.2,2.8,4.8,1.8,Iris-virginica
128+
6.1,3.0,4.9,1.8,Iris-virginica
129+
6.4,2.8,5.6,2.1,Iris-virginica
130+
7.2,3.0,5.8,1.6,Iris-virginica
131+
7.4,2.8,6.1,1.9,Iris-virginica
132+
7.9,3.8,6.4,2.0,Iris-virginica
133+
6.4,2.8,5.6,2.2,Iris-virginica
134+
6.3,2.8,5.1,1.5,Iris-virginica
135+
6.1,2.6,5.6,1.4,Iris-virginica
136+
7.7,3.0,6.1,2.3,Iris-virginica
137+
6.3,3.4,5.6,2.4,Iris-virginica
138+
6.4,3.1,5.5,1.8,Iris-virginica
139+
6.0,3.0,4.8,1.8,Iris-virginica
140+
6.9,3.1,5.4,2.1,Iris-virginica
141+
6.7,3.1,5.6,2.4,Iris-virginica
142+
6.9,3.1,5.1,2.3,Iris-virginica
143+
5.8,2.7,5.1,1.9,Iris-virginica
144+
6.8,3.2,5.9,2.3,Iris-virginica
145+
6.7,3.3,5.7,2.5,Iris-virginica
146+
6.7,3.0,5.2,2.3,Iris-virginica
147+
6.3,2.5,5.0,1.9,Iris-virginica
148+
6.5,3.0,5.2,2.0,Iris-virginica
149+
6.2,3.4,5.4,2.3,Iris-virginica
150+
5.9,3.0,5.1,1.8,Iris-virginica

examples/iris/iris.hs

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
{-# LANGUAGE OverloadedStrings #-}
2+
{-# LANGUAGE Arrows #-}
3+
{-# LANGUAGE DataKinds #-}
4+
5+
import Control.Applicative
6+
import Control.Arrow hiding (loop)
7+
import Data.Attoparsec.Text
8+
import qualified Data.Text as T
9+
import MyPrelude
10+
import Neural
11+
import Utils
12+
13+
main :: IO ()
14+
main = do
15+
xs <- readSamples
16+
printf "read %d samples\n" (length xs)
17+
evalComponentM component (mkStdGen 123456) $ do
18+
19+
let getError' = do
20+
c <- get
21+
let Just e = fromAnalytic $ getError c err xs
22+
23+
return e
24+
25+
let getQuota = do
26+
c <- get
27+
let c' = predict c
28+
n = length $ filter (uncurry (==)) [(activate c' a, i) | (a, i) <- xs]
29+
q = fromIntegral n / fromIntegral (length xs) :: Double
30+
return q
31+
32+
let loop i = do
33+
batch <- takeR batchSize xs
34+
_ <- descentM err eta batch
35+
if i `mod` 1000 == 0
36+
then do
37+
e <- getError'
38+
q <- getQuota
39+
liftIO $ printf "%6d %8.6f %6.4f\n" i e q
40+
unless (q >= 0.99) $ loop (succ i)
41+
else loop (succ i)
42+
43+
randomizeM
44+
getError' >>= \e -> liftIO $ printf "initial error is %f\n" e
45+
loop (1 :: Int)
46+
47+
batchSize :: Int
48+
batchSize = 10
49+
50+
eta :: Double
51+
eta = 0.0001
52+
53+
data Iris = Setosa | Versicolor | Virginica deriving (Show, Read, Eq, Ord, Enum)
54+
55+
data Attributes = Attributes Double Double Double Double deriving (Show, Read, Eq, Ord)
56+
57+
type Sample = (Attributes, Iris)
58+
59+
irisParser :: Parser Iris
60+
irisParser = string "Iris-setosa" *> return Setosa
61+
<|> string "Iris-versicolor" *> return Versicolor
62+
<|> string "Iris-virginica" *> return Virginica
63+
64+
65+
66+
sampleParser :: Parser Sample
67+
sampleParser = f <$> (double <* char ',')
68+
<*> (double <* char ',')
69+
<*> (double <* char ',')
70+
<*> (double <* char ',')
71+
<*> irisParser
72+
where f sl sw pl pw i = (Attributes sl sw pl pw, i)
73+
74+
readSamples :: IO [Sample]
75+
readSamples = do
76+
ls <- T.lines . T.pack <$> readFile ("examples" </> "iris" </> "data" <.> "csv")
77+
return $ f <$> ls
78+
79+
where
80+
81+
f l = let Right x = parseOnly sampleParser l in x
82+
83+
component :: Component Attributes (Vector 3 Analytic)
84+
component = let l1 = tanhLayer :: Layer 4 2
85+
l2 = tanhLayer :: Layer 2 3
86+
f (Attributes sl sw pl pw) = cons sl (cons sw (cons pl (cons pw nil)))
87+
in f ^>> fmap fromDouble ^>> l1 >>> l2 >>^ softmax
88+
89+
err :: Err Attributes (Vector 3 Analytic) Sample
90+
err c = proc (a, i) -> do
91+
y <- c -< a
92+
let y' = case i of
93+
Setosa -> cons 1 (cons 0 (cons 0 nil))
94+
Versicolor -> cons 0 (cons 1 (cons 0 nil))
95+
Virginica -> cons 0 (cons 0 (cons 1 nil))
96+
d = (-) <$> y <*> y'
97+
e = d <%> d
98+
returnA -< e
99+
100+
predict :: Component Attributes (Vector 3 Analytic) -> Component Attributes Iris
101+
predict c = c >>^ f where
102+
103+
f ys = let Just y0 = ys !? 0
104+
Just y1 = ys !? 1
105+
Just y2 = ys !? 2
106+
in if y0 >= max y1 y2
107+
then Setosa
108+
else if y1 >= y2 then Versicolor
109+
else Virginica
110+

neural.cabal

+13-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ test-suite neural-test
6363
other-modules: Neural.DescentSpec
6464
, Utils.MatrixSpec
6565
, Utils.VectorSpec
66-
ghc-options: -Wall -threaded
66+
ghc-options: -Wall -threaded -rtsopts -O2 -with-rtsopts=-N -fexcess-precision -optc-O3 -optc-ffast-math
6767
default-language: Haskell2010
6868

6969
test-suite neural-doctest
@@ -73,7 +73,18 @@ test-suite neural-doctest
7373
build-depends: base >= 4.7 && < 5
7474
, doctest
7575
, neural
76-
ghc-options: -Wall -threaded
76+
ghc-options: -Wall -threaded -rtsopts -O2 -with-rtsopts=-N -fexcess-precision -optc-O3 -optc-ffast-math
77+
78+
default-language: Haskell2010
79+
80+
executable iris
81+
hs-source-dirs: examples/iris
82+
main-is: iris.hs
83+
build-depends: base >= 4.7 && < 5
84+
, attoparsec
85+
, neural
86+
, text
87+
ghc-options: -Wall -threaded -rtsopts -O2 -with-rtsopts=-N -fexcess-precision -optc-O3 -optc-ffast-math
7788
default-language: Haskell2010
7889

7990
source-repository head

src/Neural/Component.hs

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
{-# LANGUAGE DeriveFunctor #-}
66
{-# LANGUAGE DeriveFoldable #-}
77
{-# LANGUAGE DeriveTraversable #-}
8-
{-# LANGUAGE Arrows #-}
98

109
{-|
1110
Module : Neural.Component

src/Neural/Descent.hs

+9-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ A special case of this is the /backpropagation algorithm/ for neural networks.
1717

1818
module Neural.Descent
1919
( Err
20+
, getError
2021
, descent'
2122
, descent
2223
, descentM
@@ -36,6 +37,12 @@ import Utils.Statistics (mean)
3637
--
3738
type Err a b c = forall t. Component' t a b -> Component' t c Analytic
3839

40+
-- | Computes the average error of a component for an error transformation and collection of samples.
41+
--
42+
getError :: (Functor f, Foldable f) => Component a b -> Err a b c -> f c -> Analytic
43+
getError (Component ws c _) err xs = let c' = convolve (err c) >>^ toList >>^ mean
44+
in runC c' xs $ fromDouble <$> ws
45+
3946
-- | This function performs one step of gradient descent for the given component and error transformation.
4047
--
4148
descent' :: Component a b -- ^ the component whose error should be decreased
@@ -54,9 +61,9 @@ descent :: Component a b -- ^ the component whose error should be d
5461
-> Double -- ^ the learning rate
5562
-> [c] -- ^ the mini batch of samples
5663
-> (Double, Component a b) -- ^ the mean error and the improved component
57-
descent c err eta xs = descent' c err' eta xs where
64+
descent c err = descent' c err' where
5865

59-
err' = \c' -> convolve (err c') >>^ mean
66+
err' c' = convolve (err c') >>^ mean
6067

6168
-- | This is the monadic version of 'descent': It performs one step of gradient descent on a 'mini batch'
6269
-- of samples and implicitly updates the state-component's weights.

src/Neural/Layer.hs

+11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ module Neural.Layer
77
, linearLayer
88
, layer
99
, tanhLayer
10+
, logisticLayer
11+
, softmax
1012
) where
1113

1214
import Control.Arrow
@@ -38,3 +40,12 @@ layer f = arr (fmap f) . linearLayer
3840

3941
tanhLayer :: (KnownNat i, KnownNat o) => Layer i o
4042
tanhLayer = layer tanh
43+
44+
logisticLayer :: (KnownNat i, KnownNat o) => Layer i o
45+
logisticLayer = layer $ \x -> 1 / (1 + exp (- x))
46+
47+
softmax :: (Floating a, Functor f, Foldable f) => f a -> f a
48+
softmax xs = let xs' = exp <$> xs
49+
s = sum xs'
50+
in (/ s) <$> xs'
51+

0 commit comments

Comments
 (0)