8
8
{-# LANGUAGE Arrows #-}
9
9
{-# LANGUAGE GADTs #-}
10
10
{-# LANGUAGE KindSignatures #-}
11
+ {-# LANGUAGE FlexibleContexts #-}
11
12
12
13
{-|
13
14
Module : Numeric.Neural.Model
@@ -40,6 +41,7 @@ module Numeric.Neural.Model
40
41
, mkStdModel
41
42
) where
42
43
44
+ import Control.Applicative
43
45
import Control.Arrow
44
46
import Control.Category
45
47
import Data.Profunctor
@@ -88,7 +90,7 @@ instance Profunctor (ParamFun t) where dimap = dimapArr
88
90
-- In contrast to 'ParamFun', when components are composed, parameters are not shared.
89
91
-- Each component carries its own collection of parameters instead.
90
92
--
91
- data Component a b = forall t . (Traversable t , Applicative t ) => Component
93
+ data Component a b = forall t . (Traversable t , Applicative t , NFData ( t Double ) ) => Component
92
94
{ weights :: t Double -- ^ the specific parameter values
93
95
, compute :: ParamFun t a b -- ^ the encapsulated parameterized function
94
96
, initR :: forall m . MonadRandom m => m (t Double ) -- ^ randomly sets the parameters
@@ -115,8 +117,16 @@ instance Applicative Empty where
115
117
116
118
Empty <*> Empty = Empty
117
119
120
+ instance NFData (Empty a ) where
121
+
122
+ rnf Empty = ()
123
+
118
124
data Pair s t a = Pair (s a ) (t a ) deriving (Show , Read , Eq , Ord , Functor , Foldable , Traversable )
119
125
126
+ instance (NFData (s a ), NFData (t a )) => NFData (Pair s t a ) where
127
+
128
+ rnf (Pair xs ys) = rnf xs `seq` rnf ys `seq` ()
129
+
120
130
instance (Applicative s , Applicative t ) => Applicative (Pair s t ) where
121
131
122
132
pure x = Pair (pure x) (pure x)
@@ -161,6 +171,10 @@ instance Applicative (Component a) where pure = pureArr; (<*>) = apArr
161
171
162
172
instance Profunctor Component where dimap = dimapArr
163
173
174
+ instance NFData (Component a b ) where
175
+
176
+ rnf (Component ws _ _) = rnf ws
177
+
164
178
-- | A @'Model' f g a b c@ wraps a @'Component' (f 'Analytic') (g 'Analytic')@
165
179
-- and models functions @b -> c@ with "samples" (for model error determination)
166
180
-- of type @a@.
@@ -178,6 +192,10 @@ instance Profunctor (Model f g a) where
178
192
179
193
dimap m n (Model c e i o) = Model c e (i . m) (n . o)
180
194
195
+ instance NFData (Model f g a b c ) where
196
+
197
+ rnf (Model c _ _ _) = rnf c
198
+
181
199
-- | A 'Lens' for accessing the component embedded in a model.
182
200
--
183
201
_component :: Lens' (Model f g a b c ) (Component (f Analytic ) (g Analytic ))
@@ -195,28 +213,29 @@ modelR (Model c e i o) = case c of
195
213
ws <- r
196
214
return $ Model (Component ws f r) e i o
197
215
198
- errFun :: (Functor f , Foldable h , Traversable t )
216
+ errFun :: (Functor f , Traversable t )
199
217
=> (a -> (f Double , g Analytic -> Analytic ))
200
- -> h a
218
+ -> a
201
219
-> ParamFun t (f Analytic ) (g Analytic )
202
220
-> (t Analytic -> Analytic )
203
- errFun e xs f = runPF f' xs where
204
-
205
- f' = toList ^>> convolve f'' >>^ mean
221
+ errFun e x f = runPF f' x where
206
222
207
- f'' = proc x -> do
208
- let (x', h) = e x
223
+ f' = proc z -> do
224
+ let (x', h) = e z
209
225
x'' = fromDouble <$> x'
210
226
y <- f -< x''
211
227
returnA -< h y
212
228
229
+ modelError' :: Model f g a b c -> a -> Double
230
+ modelError' (Model c e _ _) x = case c of
231
+ Component ws f _ -> let f' = errFun e x f
232
+ f'' = fromJust . fromAnalytic . f' . fmap fromDouble
233
+ in f'' ws
234
+
213
235
-- | Calculates the average model error for a "mini-batch" of samples.
214
236
--
215
237
modelError :: Foldable h => Model f g a b c -> h a -> Double
216
- modelError (Model c e _ _) xs = case c of
217
- Component ws f _ -> let f' = errFun e xs f
218
- f'' = fromJust . fromAnalytic . f' . fmap fromDouble
219
- in f'' ws
238
+ modelError m xs = mean $ modelError' m <$> toList xs
220
239
221
240
-- | Performs one step of gradient descent/ backpropagation on the model,
222
241
descent :: (Foldable h )
@@ -226,10 +245,13 @@ descent :: (Foldable h)
226
245
-> (Double , Model f g a b c ) -- ^ returns the average sample error and the improved model
227
246
descent (Model c e i o) eta xs = case c of
228
247
Component ws f r ->
229
- let f' = errFun e xs f
230
- (err, ws') = gradient (\ w dw -> w - eta * dw) f' ws
231
- c' = Component ws' f r
232
- m = Model c' e i o
248
+ let f' x = gradient (\ _ dw -> dw) (errFun e x f) ws
249
+ ys = f' <$> toList xs
250
+ err = mean $ fst <$> ys
251
+ grad = (* eta) . mean . getZipList <$> sequenceA (ZipList (snd <$> ys))
252
+ ws' = (-) <$> ws <*> grad
253
+ c' = Component ws' f r
254
+ m = Model c' e i o
233
255
in (err, m)
234
256
235
257
-- | A type abbreviation for the most common type of models, where samples are just input-output tuples.
0 commit comments