Skip to content

Commit e928ed0

Browse files
author
Bart Schuurmans
committed
Hide implementation details of extending a graph in FFI module
1 parent f8edd59 commit e928ed0

File tree

3 files changed

+19
-29
lines changed

3 files changed

+19
-29
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,7 @@ module TensorFlow.Internal.FFI
2727
, Raw.SessionOptions
2828

2929
, Raw.Graph
30-
, Raw.graphOperationByName
31-
, resolveOutput
32-
, resolveOperation
33-
, graphImportGraphDef
34-
, forGraphOperations_
35-
36-
, Raw.Operation
37-
, Raw.operationName
38-
, Raw.operationNumOutputs
39-
40-
, Raw.ImportGraphDefOptions
41-
, Raw.importGraphDefOptionsAddInputMapping
42-
43-
, Raw.Output(..)
30+
, extendGraph
4431

4532
, TensorData(..)
4633
, setSessionConfig
@@ -60,6 +47,7 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
6047
import Control.Monad.IO.Class (MonadIO, liftIO)
6148
import Data.Bits (Bits, toIntegralSized)
6249
import Data.Int (Int64)
50+
import Data.Foldable (for_)
6351
import Data.Maybe (fromMaybe)
6452
import Data.Typeable (Typeable)
6553
import Data.Word (Word8)
@@ -174,6 +162,21 @@ forGraphOperations_ graph f = with 0 go
174162
Raw.Operation ptr | ptr == nullPtr -> return ()
175163
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
176164

165+
extendGraph :: Raw.Graph -> GraphDef -> IO ()
166+
extendGraph graph graphDef =
167+
graphImportGraphDef graph graphDef $ \opts ->
168+
-- All inputs of the nodes in the GraphDef should either refer to
169+
-- other nodes in the GraphDef, or be mapped to nodes already in
170+
-- the Graph by adding an input mapping.
171+
-- We add an input mapping for all existing nodes in the Graph in
172+
-- case they are referenced in the GraphDef.
173+
forGraphOperations_ graph $ \op -> do
174+
srcName <- Raw.operationName op
175+
numOutputs <- Raw.operationNumOutputs op
176+
for_ [0..numOutputs] $ \srcIndex -> do
177+
let dst = Raw.Output op (safeConvert srcIndex)
178+
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
179+
177180
run :: Raw.Session
178181
-> Raw.Graph
179182
-> [(String, TensorData)] -- ^ Inputs.

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ instance Storable Operation where
7676
-- Output.
7777
data Output = Output
7878
{ outputOperation :: Operation
79-
, outputIndex :: Int
79+
, outputIndex :: CInt
8080
}
8181
{# pointer *TF_Output as OutputPtr -> Output #}
8282

tensorflow/src/TensorFlow/Session.hs

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,8 @@ import Control.Monad.Trans.Class (MonadTrans, lift)
4545
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
4646
import Data.ByteString (ByteString)
4747
import Data.Default (Default, def)
48-
import Data.Foldable (for_)
4948
import Data.ProtoLens (showMessage)
5049
import Data.Set (Set)
51-
import Foreign.Marshal.Utils (with)
5250
import Lens.Family2 (Lens', (^.), (&), (.~))
5351
import Lens.Family2.Unchecked (lens)
5452
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
@@ -156,18 +154,7 @@ extend = do
156154
unless (null nodesToExtend) $ liftIO $ do
157155
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
158156
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
159-
FFI.graphImportGraphDef graph graphDef $ \opts ->
160-
-- All inputs of the nodes in the GraphDef should either refer to
161-
-- other nodes in the GraphDef, or be mapped to nodes already in
162-
-- the Graph by adding an input mapping.
163-
-- We add an input mapping for all existing nodes in the Graph in
164-
-- case they are referenced in the GraphDef.
165-
FFI.forGraphOperations_ graph $ \op -> do
166-
srcName <- FFI.operationName op
167-
numOutputs <- FFI.operationNumOutputs op
168-
for_ [0..numOutputs] $ \srcIndex -> do
169-
let dst = FFI.Output op srcIndex
170-
with dst $ FFI.importGraphDefOptionsAddInputMapping opts srcName srcIndex
157+
FFI.extendGraph graph graphDef
171158
-- Now that all the nodes are created, run the initializers.
172159
initializers <- build flushInitializers
173160
unless (null initializers) $

0 commit comments

Comments
 (0)