Skip to content

Commit f8edd59

Browse files
author
Bart Schuurmans
committed
Pass input maping to TF_GraphImportGraphDef for existing nodes
1 parent d36626b commit f8edd59

File tree

3 files changed

+112
-87
lines changed

3 files changed

+112
-87
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 71 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,29 @@
1919
module TensorFlow.Internal.FFI
2020
( TensorFlowException(..)
2121
, Raw.Session
22-
, Raw.SessionOptions
2322
, withSession
23+
, run
24+
2425
, SessionAction
26+
27+
, Raw.SessionOptions
28+
2529
, Raw.Graph
26-
, importGraphDef
27-
, run
30+
, Raw.graphOperationByName
31+
, resolveOutput
32+
, resolveOperation
33+
, graphImportGraphDef
34+
, forGraphOperations_
35+
36+
, Raw.Operation
37+
, Raw.operationName
38+
, Raw.operationNumOutputs
39+
40+
, Raw.ImportGraphDefOptions
41+
, Raw.importGraphDefOptionsAddInputMapping
42+
43+
, Raw.Output(..)
44+
2845
, TensorData(..)
2946
, setSessionConfig
3047
, setSessionTarget
@@ -46,9 +63,8 @@ import Data.Int (Int64)
4663
import Data.Maybe (fromMaybe)
4764
import Data.Typeable (Typeable)
4865
import Data.Word (Word8)
49-
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
50-
import Foreign.C.String (CString)
51-
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
66+
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
67+
import Foreign.ForeignPtr (newForeignPtr_)
5268
import Foreign.Marshal.Alloc (free)
5369
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
5470
import System.IO.Unsafe (unsafePerformIO)
@@ -137,35 +153,50 @@ shutDownRunner r = do
137153
-- TODO(gnezdo): manage exceptions better than print.
138154
either print (const (return ())) =<< waitCatch r
139155

140-
importGraphDef :: Raw.Graph -> GraphDef -> IO ()
141-
importGraphDef graph pb =
156+
graphImportGraphDef :: Raw.Graph
157+
-> GraphDef
158+
-> (Raw.ImportGraphDefOptions -> IO ())
159+
-> IO ()
160+
graphImportGraphDef graph pb optionSetter =
142161
useProtoAsBuffer pb $ \buffer ->
143-
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions ->
144-
checkStatus $ Raw.importGraphDef graph buffer importGraphDefOptions
162+
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
163+
optionSetter importGraphDefOptions
164+
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
165+
166+
forGraphOperations_ :: Raw.Graph
167+
-> (Raw.Operation -> IO b)
168+
-> IO ()
169+
forGraphOperations_ graph f = with 0 go
170+
where
171+
go indexPtr = do
172+
op <- Raw.graphNextOperation graph indexPtr
173+
case op of
174+
Raw.Operation ptr | ptr == nullPtr -> return ()
175+
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
145176

146177
run :: Raw.Session
147178
-> Raw.Graph
148-
-> [(B.ByteString, TensorData)] -- ^ Inputs.
149-
-> [B.ByteString] -- ^ Outputs.
150-
-> [B.ByteString] -- ^ Target operations.
179+
-> [(String, TensorData)] -- ^ Inputs.
180+
-> [String] -- ^ Outputs.
181+
-> [String] -- ^ Target operations.
151182
-> IO [TensorData]
152183
run session graph inputNamesData outputNames targetNames = do
153184
-- Use mask to avoid leaking input tensors before they are passed to 'run'
154185
-- and output tensors before they are passed to 'createTensorData'.
155186
mask_ $
156187
-- Inputs.
157-
mapM (resolveOutput . fst) inputNamesData >>= \inputs ->
188+
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
158189
withArrayLen inputs $ \nInputs cInputs ->
159190
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
160191
withArrayLen inputTensors $ \_ cInputTensors ->
161192
-- Outputs.
162-
mapM resolveOutput outputNames >>= \outputs ->
193+
mapM (resolveOutput graph) outputNames >>= \outputs ->
163194
withArrayLen outputs $ \nOutputs cOutputs ->
164195
-- outputTensors is an array of null Tensor pointers that will be filled
165196
-- by the call to Raw.run.
166197
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
167198
-- Target operations.
168-
mapM resolveOperation targetNames >>= \targets ->
199+
mapM (resolveOperation graph) targetNames >>= \targets ->
169200
withArrayLen targets $ \nTargets cTargets -> do
170201
checkStatus $ Raw.run
171202
session
@@ -178,32 +209,33 @@ run session graph inputNamesData outputNames targetNames = do
178209
outTensors <- peekArray nOutputs cOutputTensors
179210
mapM createTensorData outTensors
180211
where
181-
resolveOutput :: B.ByteString -> IO Raw.Output
182-
resolveOutput name = do
183-
let (opName, idx) = parseName name
184-
op <- resolveOperation opName
185-
pure $ Raw.Output op idx
186-
187-
resolveOperation :: B.ByteString -> IO Raw.Operation
188-
resolveOperation name = do
189-
op <- B.useAsCString name $ Raw.graphOperationByName graph
190-
case op of
191-
Raw.Operation ptr | ptr == nullPtr -> throwM exception
192-
_ -> pure op
193-
where
194-
exception =
195-
let msg = "Operation not found in graph: " <> (T.pack $ show name)
196-
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
197-
198-
parseName :: B.ByteString -> (B.ByteString, CInt)
199-
parseName name =
200-
case break (== ':') (C.unpack name) of
201-
(name, ':':idxStr) | idx <- read idxStr
202-
-> (C.pack name, idx)
203-
_ -> (name, 0)
204212

205213
nullTensor = Raw.Tensor nullPtr
206214

215+
resolveOutput :: Raw.Graph -> String -> IO Raw.Output
216+
resolveOutput graph name = do
217+
let (opName, idx) = parseName name
218+
op <- resolveOperation graph opName
219+
pure $ Raw.Output op (safeConvert idx)
220+
where
221+
parseName :: String -> (String, Int)
222+
parseName opName =
223+
case break (== ':') opName of
224+
(opName_, ':':idxStr) | idx <- read idxStr
225+
-> (opName_, idx)
226+
_ -> (opName, 0)
227+
228+
resolveOperation :: Raw.Graph -> String -> IO Raw.Operation
229+
resolveOperation graph name = do
230+
op <- Raw.graphOperationByName graph name
231+
case op of
232+
Raw.Operation ptr | ptr == nullPtr -> throwM exception
233+
_ -> pure op
234+
where
235+
exception =
236+
let msg = "Operation not found in graph: " <> (T.pack $ show name)
237+
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
238+
207239

208240
-- Internal.
209241

@@ -218,21 +250,6 @@ safeConvert x =
218250
show (fromIntegral x :: b)))
219251
(toIntegralSized x)
220252

221-
222-
-- | Use a list of ByteString as a list of CString.
223-
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
224-
withStringList strings fn = go strings []
225-
where
226-
go [] cs = fn (reverse cs)
227-
-- TODO(fmayle): Is it worth using unsafeAsCString here?
228-
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
229-
230-
231-
-- | Use a list of ByteString as an array of CString.
232-
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
233-
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
234-
235-
236253
-- | Create a Raw.Tensor from a TensorData.
237254
createRawTensor :: TensorData -> IO Raw.Tensor
238255
createRawTensor (TensorData dims dt byteVec) =

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ stringGetSize = {# call TF_StringGetSize as ^ #}
6363

6464
-- Operation.
6565
{# pointer *TF_Operation as Operation newtype #}
66+
{# fun TF_OperationName as operationName { `Operation' } -> `String' #}
67+
{# fun TF_OperationNumOutputs as operationNumOutputs { `Operation' } -> `Int' #}
6668

6769
instance Storable Operation where
6870
sizeOf (Operation t) = sizeOf t
@@ -74,18 +76,18 @@ instance Storable Operation where
7476
-- Output.
7577
data Output = Output
7678
{ outputOperation :: Operation
77-
, outputIndex :: CInt
79+
, outputIndex :: Int
7880
}
7981
{# pointer *TF_Output as OutputPtr -> Output #}
8082

8183
instance Storable Output where
8284
sizeOf _ = {# sizeof TF_Output #}
8385
alignment _ = {# alignof TF_Output #}
84-
peek ptr = Output <$> {# get TF_Output->oper #} ptr
85-
<*> {# get TF_Output->index #} ptr
86-
poke ptr (Output oper index) = do
87-
{# set TF_Output->oper #} ptr oper
88-
{# set TF_Output->index #} ptr index
86+
peek p = Output <$> {# get TF_Output->oper #} p
87+
<*> (fromIntegral <$> {# get TF_Output->index #} p)
88+
poke p (Output oper index) = do
89+
{# set TF_Output->oper #} p oper
90+
{# set TF_Output->index #} p $ fromIntegral index
8991

9092

9193
-- Buffer.
@@ -119,6 +121,8 @@ instance Storable Tensor where
119121
-- `CLLong`).
120122
type CInt64 = {#type int64_t #}
121123

124+
{# pointer *size_t as CSizePtr -> CSize #}
125+
122126
newTensor :: DataType
123127
-> Ptr CInt64 -- dimensions array
124128
-> CInt -- num dimensions
@@ -149,27 +153,18 @@ tensorData = {# call TF_TensorData as ^ #}
149153

150154
-- ImportGraphDefOptions.
151155
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
156+
{# fun TF_NewImportGraphDefOptions as newImportGraphDefOptions { } -> `ImportGraphDefOptions' #}
157+
{# fun TF_DeleteImportGraphDefOptions as deleteImportGraphDefOptions { `ImportGraphDefOptions' } -> `()' #}
158+
{# fun TF_ImportGraphDefOptionsAddInputMapping as importGraphDefOptionsAddInputMapping { `ImportGraphDefOptions', `String', `Int', %`OutputPtr' } -> `()' #}
152159

153-
newImportGraphDefOptions :: IO ImportGraphDefOptions
154-
newImportGraphDefOptions = {# call TF_NewImportGraphDefOptions as ^ #}
155-
156-
deleteImportGraphDefOptions :: ImportGraphDefOptions -> IO ()
157-
deleteImportGraphDefOptions = {# call TF_DeleteImportGraphDefOptions as ^ #}
158160

159161
-- Graph.
160162
{# pointer *TF_Graph as Graph newtype #}
161-
162-
newGraph :: IO Graph
163-
newGraph = {# call TF_NewGraph as ^ #}
164-
165-
deleteGraph :: Graph -> IO ()
166-
deleteGraph = {# call TF_DeleteGraph as ^ #}
167-
168-
graphOperationByName :: Graph -> CString -> IO Operation
169-
graphOperationByName = {# call TF_GraphOperationByName as ^ #}
170-
171-
importGraphDef :: Graph -> BufferPtr -> ImportGraphDefOptions -> Status -> IO ()
172-
importGraphDef = {# call TF_GraphImportGraphDef as ^ #}
163+
{# fun TF_NewGraph as newGraph { } -> `Graph' #}
164+
{# fun TF_DeleteGraph as deleteGraph { `Graph' } -> `()' #}
165+
{# fun TF_GraphOperationByName as graphOperationByName { `Graph', `String' } -> `Operation' #}
166+
{# fun TF_GraphNextOperation as graphNextOperation { `Graph', `CSizePtr' } -> `Operation' #}
167+
{# fun TF_GraphImportGraphDef as graphImportGraphDef { `Graph', `BufferPtr', `ImportGraphDefOptions', `Status' } -> `()' #}
173168

174169

175170
-- Session Options.

tensorflow/src/TensorFlow/Session.hs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,24 @@ import Control.Monad.Trans.Class (MonadTrans, lift)
4545
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks)
4646
import Data.ByteString (ByteString)
4747
import Data.Default (Default, def)
48+
import Data.Foldable (for_)
4849
import Data.ProtoLens (showMessage)
4950
import Data.Set (Set)
50-
import Data.Text.Encoding (encodeUtf8)
51+
import Foreign.Marshal.Utils (with)
5152
import Lens.Family2 (Lens', (^.), (&), (.~))
5253
import Lens.Family2.Unchecked (lens)
5354
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
5455
import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
5556
import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
5657
import TensorFlow.Build
5758
import TensorFlow.Nodes
58-
import TensorFlow.Output (NodeName, unNodeName)
59+
import TensorFlow.Output (NodeName(..), unNodeName)
5960
import TensorFlow.Tensor
6061

6162
import qualified Data.ByteString.Builder as Builder
6263
import qualified Data.Map.Strict as Map
6364
import qualified Data.Set as Set
65+
import qualified Data.Text as T
6466
import qualified TensorFlow.Internal.FFI as FFI
6567

6668
-- | An action for logging.
@@ -130,8 +132,8 @@ _runSessionWithOptions :: (MonadMask m, MonadIO m)
130132
-> m a
131133
_runSessionWithOptions (Session m) options withSession =
132134
withSession applyOptions $
133-
\asyncCollector rawSession rawGraph ->
134-
let initState = SessionState rawSession rawGraph asyncCollector (options ^. sessionTracer)
135+
\ac rSession rGraph ->
136+
let initState = SessionState rSession rGraph ac (options ^. sessionTracer)
135137
in evalBuildT (runReaderT m initState)
136138
where
137139
applyOptions opt = do
@@ -154,7 +156,18 @@ extend = do
154156
unless (null nodesToExtend) $ liftIO $ do
155157
let graphDef = (defMessage :: GraphDef) & node .~ nodesToExtend
156158
trace ("Session.extend " <> Builder.string8 (showMessage graphDef))
157-
FFI.importGraphDef graph graphDef
159+
FFI.graphImportGraphDef graph graphDef $ \opts ->
160+
-- All inputs of the nodes in the GraphDef should either refer to
161+
-- other nodes in the GraphDef, or be mapped to nodes already in
162+
-- the Graph by adding an input mapping.
163+
-- We add an input mapping for all existing nodes in the Graph in
164+
-- case they are referenced in the GraphDef.
165+
FFI.forGraphOperations_ graph $ \op -> do
166+
srcName <- FFI.operationName op
167+
numOutputs <- FFI.operationNumOutputs op
168+
for_ [0..numOutputs] $ \srcIndex -> do
169+
let dst = FFI.Output op srcIndex
170+
with dst $ FFI.importGraphDefOptionsAddInputMapping opts srcName srcIndex
158171
-- Now that all the nodes are created, run the initializers.
159172
initializers <- build flushInitializers
160173
unless (null initializers) $
@@ -181,7 +194,7 @@ runFetchWithFeeds :: MonadIO m => [Feed] -> Set NodeName -> Fetch a -> SessionT
181194
runFetchWithFeeds feeds target (Fetch fetch restore) = do
182195
extend
183196
let feeds' = fixFeeds feeds
184-
let fetchNames = encodeUtf8 <$> Set.toList fetch
197+
let fetchNames = T.unpack <$> Set.toList fetch
185198
targetNames = toNodeNames $ Set.toList target
186199
state <- Session ask
187200
runResult <- liftIO $ FFI.run (rawSession state)
@@ -192,8 +205,8 @@ runFetchWithFeeds feeds target (Fetch fetch restore) = do
192205
let resultTensorsMap = Map.fromList $ zip (Set.toList fetch) runResult
193206
return $ restore resultTensorsMap
194207

195-
toNodeNames :: [NodeName] -> [ByteString]
196-
toNodeNames = map (encodeUtf8 . unNodeName)
208+
toNodeNames :: [NodeName] -> [String]
209+
toNodeNames = map (T.unpack . unNodeName)
197210

198211
-- | Run a subgraph 't', rendering and extending any dependent nodes that aren't
199212
-- already rendered. This behaves like 'run' except that it doesn't do any
@@ -210,8 +223,8 @@ runWithFeeds_ feeds t = do
210223
ns <- build $ getNodes t
211224
runFetchWithFeeds feeds ns (pure ())
212225

213-
fixFeeds :: [Feed] -> [(ByteString, FFI.TensorData)]
214-
fixFeeds = map $ \(Feed o d) -> (encodeUtf8 $ encodeOutput o, d)
226+
fixFeeds :: [Feed] -> [(String, FFI.TensorData)]
227+
fixFeeds = map $ \(Feed o d) -> (T.unpack $ encodeOutput o, d)
215228

216229
-- | Starts a concurrent thread which evaluates the given Nodes
217230
-- forever until runSession exits or an exception occurs. Graph

0 commit comments

Comments
 (0)