@@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
20
20
( TensorFlowException (.. )
21
21
, Raw. Session
22
22
, withSession
23
- , extendGraph
24
23
, run
24
+
25
+ , SessionAction
26
+
27
+ , Raw. SessionOptions
28
+
29
+ , Raw. Graph
30
+ , extendGraph
31
+
25
32
, TensorData (.. )
26
33
, setSessionConfig
27
34
, setSessionTarget
@@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
40
47
import Control.Monad.IO.Class (MonadIO , liftIO )
41
48
import Data.Bits (Bits , toIntegralSized )
42
49
import Data.Int (Int64 )
50
+ import Data.Foldable (for_ )
43
51
import Data.Maybe (fromMaybe )
44
52
import Data.Typeable (Typeable )
45
53
import Data.Word (Word8 )
46
- import Foreign (Ptr , FunPtr , nullPtr , castPtr )
47
- import Foreign.C.String (CString )
48
- import Foreign.ForeignPtr (newForeignPtr , newForeignPtr_ , withForeignPtr )
54
+ import Foreign (Ptr , FunPtr , nullPtr , castPtr , with )
55
+ import Foreign.ForeignPtr (newForeignPtr_ )
49
56
import Foreign.Marshal.Alloc (free )
50
57
import Foreign.Marshal.Array (withArrayLen , peekArray , mallocArray , copyArray )
51
58
import System.IO.Unsafe (unsafePerformIO )
52
59
import qualified Data.ByteString as B
60
+ import qualified Data.ByteString.Char8 as C
53
61
import qualified Data.Text as T
54
62
import qualified Data.Text.Encoding as T
55
63
import qualified Data.Text.Encoding.Error as T
@@ -87,15 +95,26 @@ data TensorData = TensorData
87
95
}
88
96
deriving (Show , Eq )
89
97
98
+ -- | The action can spawn concurrent tasks which will be canceled before
99
+ -- withSession returns.
100
+ type SessionAction m a = (IO () -> IO () ) -> Raw. Session -> Raw. Graph -> m a
101
+
90
102
-- | Runs the given action after creating a session with options
91
103
-- populated by the given optionSetter.
92
104
withSession :: (MonadIO m , MonadMask m )
93
105
=> (Raw. SessionOptions -> IO () )
94
- -> ((IO () -> IO () ) -> Raw. Session -> m a )
95
- -- ^ The action can spawn concurrent tasks which will
96
- -- be canceled before withSession returns.
106
+ -> SessionAction m a
97
107
-> m a
98
- withSession optionSetter action = do
108
+ withSession = withSession_ Raw. newSession
109
+
110
+ withSession_ :: (MonadIO m , MonadMask m )
111
+ => (Raw. Graph -> Raw. SessionOptions -> Raw. Status -> IO Raw. Session )
112
+ -- ^ mkSession
113
+ -> (Raw. SessionOptions -> IO () )
114
+ -- ^ optionSetter
115
+ -> SessionAction m a
116
+ -> m a
117
+ withSession_ mkSession optionSetter action = do
99
118
drain <- liftIO $ newMVar []
100
119
let cleanup s =
101
120
-- Closes the session to nudge the pending run calls to fail and exit.
@@ -105,11 +124,12 @@ withSession optionSetter action = do
105
124
mapM_ shutDownRunner runners
106
125
checkStatus (Raw. deleteSession s)
107
126
let bracketIO x y = bracket (liftIO x) (liftIO . y)
108
- bracketIO Raw. newSessionOptions Raw. deleteSessionOptions $ \ options -> do
109
- bracketIO
110
- (optionSetter options >> checkStatus (Raw. newSession options))
111
- cleanup
112
- (action (asyncCollector drain))
127
+ bracketIO Raw. newGraph Raw. deleteGraph $ \ graph ->
128
+ bracketIO Raw. newSessionOptions Raw. deleteSessionOptions $ \ options -> do
129
+ bracketIO
130
+ (optionSetter options >> checkStatus (mkSession graph options))
131
+ cleanup
132
+ (\ session -> action (asyncCollector drain) session graph)
113
133
114
134
asyncCollector :: MVar [Async () ] -> IO () -> IO ()
115
135
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
@@ -122,43 +142,103 @@ shutDownRunner r = do
122
142
-- TODO(gnezdo): manage exceptions better than print.
123
143
either print (const (return () )) =<< waitCatch r
124
144
125
- extendGraph :: Raw. Session -> GraphDef -> IO ()
126
- extendGraph session pb =
127
- useProtoAsVoidPtrLen pb $ \ ptr len ->
128
- checkStatus $ Raw. extendGraph session ptr len
129
-
145
+ graphImportGraphDef :: Raw. Graph
146
+ -> GraphDef
147
+ -> (Raw. ImportGraphDefOptions -> IO () )
148
+ -> IO ()
149
+ graphImportGraphDef graph pb optionSetter =
150
+ useProtoAsBuffer pb $ \ buffer ->
151
+ bracket Raw. newImportGraphDefOptions Raw. deleteImportGraphDefOptions $ \ importGraphDefOptions -> do
152
+ optionSetter importGraphDefOptions
153
+ checkStatus $ Raw. graphImportGraphDef graph buffer importGraphDefOptions
154
+
155
+ forGraphOperations_ :: Raw. Graph
156
+ -> (Raw. Operation -> IO b )
157
+ -> IO ()
158
+ forGraphOperations_ graph f = with 0 go
159
+ where
160
+ go indexPtr = do
161
+ op <- Raw. graphNextOperation graph indexPtr
162
+ case op of
163
+ Raw. Operation ptr | ptr == nullPtr -> return ()
164
+ _ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
165
+
166
+ extendGraph :: Raw. Graph -> GraphDef -> IO ()
167
+ extendGraph graph graphDef =
168
+ graphImportGraphDef graph graphDef $ \ opts ->
169
+ -- All inputs of the nodes in the GraphDef should either refer to
170
+ -- other nodes in the GraphDef, or be mapped to nodes already in
171
+ -- the Graph by adding an input mapping.
172
+ -- We add an input mapping for all existing nodes in the Graph in
173
+ -- case they are referenced in the GraphDef.
174
+ forGraphOperations_ graph $ \ op -> do
175
+ srcName <- Raw. operationName op
176
+ numOutputs <- Raw. operationNumOutputs op
177
+ for_ [0 .. numOutputs] $ \ srcIndex -> do
178
+ let dst = Raw. Output op (safeConvert srcIndex)
179
+ with dst $ Raw. importGraphDefOptionsAddInputMapping opts srcName srcIndex
130
180
131
181
run :: Raw. Session
132
- -> [(B. ByteString , TensorData )] -- ^ Feeds.
133
- -> [B. ByteString ] -- ^ Fetches.
134
- -> [B. ByteString ] -- ^ Targets.
182
+ -> Raw. Graph
183
+ -> [(B. ByteString , TensorData )] -- ^ Inputs.
184
+ -> [B. ByteString ] -- ^ Outputs.
185
+ -> [B. ByteString ] -- ^ Target operations.
135
186
-> IO [TensorData ]
136
- run session feeds fetches targets = do
137
- let nullTensor = Raw. Tensor nullPtr
187
+ run session graph inputNamesData outputNames targetNames = do
138
188
-- Use mask to avoid leaking input tensors before they are passed to 'run'
139
189
-- and output tensors before they are passed to 'createTensorData'.
140
190
mask_ $
141
- -- Feeds
142
- withStringArrayLen (fst <$> feeds) $ \ feedsLen feedNames ->
143
- mapM (createRawTensor . snd ) feeds >>= \ feedTensors ->
144
- withArrayLen feedTensors $ \ _ cFeedTensors ->
145
- -- Fetches.
146
- withStringArrayLen fetches $ \ fetchesLen fetchNames ->
147
- -- tensorOuts is an array of null Tensor pointers that will be filled
191
+ -- Inputs.
192
+ mapM (resolveOutput graph . fst ) inputNamesData >>= \ inputs ->
193
+ withArrayLen inputs $ \ nInputs cInputs ->
194
+ mapM (createRawTensor . snd ) inputNamesData >>= \ inputTensors ->
195
+ withArrayLen inputTensors $ \ _ cInputTensors ->
196
+ -- Outputs.
197
+ mapM (resolveOutput graph) outputNames >>= \ outputs ->
198
+ withArrayLen outputs $ \ nOutputs cOutputs ->
199
+ -- outputTensors is an array of null Tensor pointers that will be filled
148
200
-- by the call to Raw.run.
149
- withArrayLen (replicate fetchesLen nullTensor) $ \ _ tensorOuts ->
150
- -- Targets.
151
- withStringArrayLen targets $ \ targetsLen ctargets -> do
201
+ withArrayLen (replicate nOutputs nullTensor) $ \ _ cOutputTensors ->
202
+ -- Target operations.
203
+ mapM (resolveOperation graph) targetNames >>= \ targets ->
204
+ withArrayLen targets $ \ nTargets cTargets -> do
152
205
checkStatus $ Raw. run
153
206
session
154
- nullPtr
155
- feedNames cFeedTensors (safeConvert feedsLen )
156
- fetchNames tensorOuts (safeConvert fetchesLen )
157
- ctargets (safeConvert targetsLen )
158
- nullPtr
159
- mapM_ Raw. deleteTensor feedTensors
160
- outTensors <- peekArray fetchesLen tensorOuts
207
+ nullPtr -- RunOptions proto.
208
+ cInputs cInputTensors (safeConvert nInputs )
209
+ cOutputs cOutputTensors (safeConvert nOutputs )
210
+ cTargets (safeConvert nTargets )
211
+ nullPtr -- RunMetadata.
212
+ mapM_ Raw. deleteTensor inputTensors
213
+ outTensors <- peekArray nOutputs cOutputTensors
161
214
mapM createTensorData outTensors
215
+ where
216
+
217
+ nullTensor = Raw. Tensor nullPtr
218
+
219
+ resolveOutput :: Raw. Graph -> B. ByteString -> IO Raw. Output
220
+ resolveOutput graph name = do
221
+ let (opName, idx) = parseName name
222
+ op <- resolveOperation graph opName
223
+ pure $ Raw. Output op (safeConvert idx)
224
+ where
225
+ parseName :: B. ByteString -> (B. ByteString , Int )
226
+ parseName opName =
227
+ case break (== ' :' ) (C. unpack opName) of
228
+ (opName_, ' :' : idxStr) | idx <- read idxStr
229
+ -> (C. pack opName_, idx)
230
+ _ -> (opName, 0 )
231
+
232
+ resolveOperation :: Raw. Graph -> B. ByteString -> IO Raw. Operation
233
+ resolveOperation graph name = do
234
+ op <- Raw. graphOperationByName graph name
235
+ case op of
236
+ Raw. Operation ptr | ptr == nullPtr -> throwM exception
237
+ _ -> pure op
238
+ where
239
+ exception =
240
+ let msg = " Operation not found in graph: " <> (T. pack $ show name)
241
+ in TensorFlowException Raw. TF_INVALID_ARGUMENT msg
162
242
163
243
164
244
-- Internal.
@@ -174,21 +254,6 @@ safeConvert x =
174
254
show (fromIntegral x :: b )))
175
255
(toIntegralSized x)
176
256
177
-
178
- -- | Use a list of ByteString as a list of CString.
179
- withStringList :: [B. ByteString ] -> ([CString ] -> IO a ) -> IO a
180
- withStringList strings fn = go strings []
181
- where
182
- go [] cs = fn (reverse cs)
183
- -- TODO(fmayle): Is it worth using unsafeAsCString here?
184
- go (x: xs) cs = B. useAsCString x $ \ c -> go xs (c: cs)
185
-
186
-
187
- -- | Use a list of ByteString as an array of CString.
188
- withStringArrayLen :: [B. ByteString ] -> (Int -> Ptr CString -> IO a ) -> IO a
189
- withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
190
-
191
-
192
257
-- | Create a Raw.Tensor from a TensorData.
193
258
createRawTensor :: TensorData -> IO Raw. Tensor
194
259
createRawTensor (TensorData dims dt byteVec) =
@@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
258
323
useProtoAsVoidPtrLen msg f = B. useAsCStringLen (encodeMessage msg) $
259
324
\ (bytes, len) -> f (castPtr bytes) (safeConvert len)
260
325
326
+ -- | Serializes the given msg and provides it as BufferPtr argument
327
+ -- to the given action.
328
+ useProtoAsBuffer :: (Message msg ) =>
329
+ msg -> (Raw. BufferPtr -> IO a ) -> IO a
330
+ useProtoAsBuffer msg f =
331
+ B. useAsCStringLen (encodeMessage msg) $ \ (bytes, len) ->
332
+ bracket (Raw. newBufferFromString (castPtr bytes) (safeConvert len))
333
+ Raw. deleteBuffer
334
+ f
335
+
261
336
-- | Returns the serialized OpList of all OpDefs defined in this
262
337
-- address space.
263
338
getAllOpList :: IO B. ByteString
264
- getAllOpList = do
265
- foreignPtr <-
266
- mask_ (newForeignPtr Raw. deleteBuffer =<< checkCall)
267
- -- Makes a copy because it is more reliable than eviscerating
268
- -- Buffer to steal its memory (including custom deallocator).
269
- withForeignPtr foreignPtr $
270
- \ ptr -> B. packCStringLen =<< (,)
271
- <$> (castPtr <$> Raw. getBufferData ptr)
272
- <*> (safeConvert <$> Raw. getBufferLength ptr)
339
+ getAllOpList =
340
+ bracket checkCall Raw. deleteBuffer $ \ buffer ->
341
+ -- Makes a copy because it is more reliable than eviscerating
342
+ -- Buffer to steal its memory (including custom deallocator).
343
+ B. packCStringLen =<< (,)
344
+ <$> (castPtr <$> Raw. getBufferData buffer)
345
+ <*> (safeConvert <$> Raw. getBufferLength buffer)
273
346
where
274
347
checkCall = do
275
348
p <- Raw. getAllOpList
0 commit comments