19
19
module TensorFlow.Internal.FFI
20
20
( TensorFlowException (.. )
21
21
, Raw. Session
22
+ , Raw. SessionOptions
22
23
, withSession
23
- , extendGraph
24
+ , SessionAction
25
+ , Raw. Graph
26
+ , importGraphDef
24
27
, run
25
28
, TensorData (.. )
26
29
, setSessionConfig
@@ -74,15 +77,26 @@ data TensorData = TensorData
74
77
}
75
78
deriving (Show , Eq )
76
79
80
+ -- | The action can spawn concurrent tasks which will be canceled before
81
+ -- withSession returns.
82
+ type SessionAction m a = (IO () -> IO () ) -> Raw. Session -> Raw. Graph -> m a
83
+
77
84
-- | Runs the given action after creating a session with options
78
85
-- populated by the given optionSetter.
79
86
withSession :: (MonadIO m , MonadMask m )
80
87
=> (Raw. SessionOptions -> IO () )
81
- -> ((IO () -> IO () ) -> Raw. Session -> m a )
82
- -- ^ The action can spawn concurrent tasks which will
83
- -- be canceled before withSession returns.
88
+ -> SessionAction m a
84
89
-> m a
85
- withSession optionSetter action = do
90
+ withSession = withSession_ Raw. newSession
91
+
92
+ withSession_ :: (MonadIO m , MonadMask m )
93
+ => (Raw. Graph -> Raw. SessionOptions -> Raw. Status -> IO Raw. Session )
94
+ -- ^ mkSession
95
+ -> (Raw. SessionOptions -> IO () )
96
+ -- ^ optionSetter
97
+ -> SessionAction m a
98
+ -> m a
99
+ withSession_ mkSession optionSetter action = do
86
100
drain <- liftIO $ newMVar []
87
101
let cleanup s =
88
102
-- Closes the session to nudge the pending run calls to fail and exit.
@@ -92,11 +106,12 @@ withSession optionSetter action = do
92
106
mapM_ shutDownRunner runners
93
107
checkStatus (Raw. deleteSession s)
94
108
let bracketIO x y = bracket (liftIO x) (liftIO . y)
95
- bracketIO Raw. newSessionOptions Raw. deleteSessionOptions $ \ options -> do
96
- bracketIO
97
- (optionSetter options >> checkStatus (Raw. newSession options))
98
- cleanup
99
- (action (asyncCollector drain))
109
+ bracketIO Raw. newGraph Raw. deleteGraph $ \ graph ->
110
+ bracketIO Raw. newSessionOptions Raw. deleteSessionOptions $ \ options -> do
111
+ bracketIO
112
+ (optionSetter options >> checkStatus (mkSession graph options))
113
+ cleanup
114
+ (\ session -> action (asyncCollector drain) session graph)
100
115
101
116
asyncCollector :: MVar [Async () ] -> IO () -> IO ()
102
117
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
@@ -109,43 +124,72 @@ shutDownRunner r = do
109
124
-- TODO(gnezdo): manage exceptions better than print.
110
125
either print (const (return () )) =<< waitCatch r
111
126
112
- extendGraph :: Raw. Session -> GraphDef -> IO ()
113
- extendGraph session pb =
114
- useProtoAsVoidPtrLen pb $ \ ptr len ->
115
- checkStatus $ Raw. extendGraph session ptr len
116
-
127
+ importGraphDef :: Raw. Graph -> GraphDef -> IO ()
128
+ importGraphDef graph pb =
129
+ useProtoAsBuffer pb $ \ buffer ->
130
+ bracket Raw. newImportGraphDefOptions Raw. deleteImportGraphDefOptions $ \ importGraphDefOptions ->
131
+ checkStatus $ Raw. importGraphDef graph buffer importGraphDefOptions
117
132
118
133
run :: Raw. Session
119
- -> [(B. ByteString , TensorData )] -- ^ Feeds.
120
- -> [B. ByteString ] -- ^ Fetches.
121
- -> [B. ByteString ] -- ^ Targets.
134
+ -> Raw. Graph
135
+ -> [(B. ByteString , TensorData )] -- ^ Inputs.
136
+ -> [B. ByteString ] -- ^ Outputs.
137
+ -> [B. ByteString ] -- ^ Target operations.
122
138
-> IO [TensorData ]
123
- run session feeds fetches targets = do
124
- let nullTensor = Raw. Tensor nullPtr
139
+ run session graph inputNamesData outputNames targetNames = do
125
140
-- Use mask to avoid leaking input tensors before they are passed to 'run'
126
141
-- and output tensors before they are passed to 'createTensorData'.
127
142
mask_ $
128
- -- Feeds
129
- withStringArrayLen (fst <$> feeds) $ \ feedsLen feedNames ->
130
- mapM (createRawTensor . snd ) feeds >>= \ feedTensors ->
131
- withArrayLen feedTensors $ \ _ cFeedTensors ->
132
- -- Fetches.
133
- withStringArrayLen fetches $ \ fetchesLen fetchNames ->
134
- -- tensorOuts is an array of null Tensor pointers that will be filled
143
+ -- Inputs.
144
+ mapM (resolveOutput . fst ) inputNamesData >>= \ inputs ->
145
+ withArrayLen inputs $ \ nInputs cInputs ->
146
+ mapM (createRawTensor . snd ) inputNamesData >>= \ inputTensors ->
147
+ withArrayLen inputTensors $ \ _ cInputTensors ->
148
+ -- Outputs.
149
+ mapM resolveOutput outputNames >>= \ outputs ->
150
+ withArrayLen outputs $ \ nOutputs cOutputs ->
151
+ -- outputTensors is an array of null Tensor pointers that will be filled
135
152
-- by the call to Raw.run.
136
- withArrayLen (replicate fetchesLen nullTensor) $ \ _ tensorOuts ->
137
- -- Targets.
138
- withStringArrayLen targets $ \ targetsLen ctargets -> do
153
+ withArrayLen (replicate nOutputs nullTensor) $ \ _ cOutputTensors ->
154
+ -- Target operations.
155
+ mapM resolveOperation targetNames >>= \ targets ->
156
+ withArrayLen targets $ \ nTargets cTargets -> do
139
157
checkStatus $ Raw. run
140
158
session
141
- nullPtr
142
- feedNames cFeedTensors (safeConvert feedsLen )
143
- fetchNames tensorOuts (safeConvert fetchesLen )
144
- ctargets (safeConvert targetsLen )
145
- nullPtr
146
- mapM_ Raw. deleteTensor feedTensors
147
- outTensors <- peekArray fetchesLen tensorOuts
159
+ nullPtr -- RunOptions proto.
160
+ cInputs cInputTensors (safeConvert nInputs )
161
+ cOutputs cOutputTensors (safeConvert nOutputs )
162
+ cTargets (safeConvert nTargets )
163
+ nullPtr -- RunMetadata.
164
+ mapM_ Raw. deleteTensor inputTensors
165
+ outTensors <- peekArray nOutputs cOutputTensors
148
166
mapM createTensorData outTensors
167
+ where
168
+ resolveOutput :: B. ByteString -> IO Raw. Output
169
+ resolveOutput name = do
170
+ let (opName, idx) = parseName name
171
+ op <- resolveOperation opName
172
+ pure $ Raw. Output op idx
173
+
174
+ resolveOperation :: B. ByteString -> IO Raw. Operation
175
+ resolveOperation name = do
176
+ op <- B. useAsCString name $ Raw. graphOperationByName graph
177
+ case op of
178
+ Raw. Operation ptr | ptr == nullPtr -> throwM exception
179
+ _ -> pure op
180
+ where
181
+ exception =
182
+ let msg = " Operation not found in graph: " <> (T. pack $ show name)
183
+ in TensorFlowException Raw. TF_INVALID_ARGUMENT msg
184
+
185
+ parseName :: B. ByteString -> (B. ByteString , CInt )
186
+ parseName name =
187
+ case break (== ' :' ) (C. unpack name) of
188
+ (name, ' :' : idxStr) | [(idx, " " :: String )] <- read idxStr
189
+ -> (C. pack name, fromInteger idx)
190
+ _ -> (name, 0 )
191
+
192
+ nullTensor = Raw. Tensor nullPtr
149
193
150
194
151
195
-- Internal.
@@ -245,18 +289,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
245
289
useProtoAsVoidPtrLen msg f = B. useAsCStringLen (encodeMessage msg) $
246
290
\ (bytes, len) -> f (castPtr bytes) (safeConvert len)
247
291
292
+ -- | Serializes the given msg and provides it as BufferPtr argument
293
+ -- to the given action.
294
+ useProtoAsBuffer :: (Message msg ) =>
295
+ msg -> (Raw. BufferPtr -> IO a ) -> IO a
296
+ useProtoAsBuffer msg f =
297
+ B. useAsCStringLen (encodeMessage msg) $ \ (bytes, len) ->
298
+ bracket (Raw. newBufferFromString (castPtr bytes) (safeConvert len))
299
+ Raw. deleteBuffer
300
+ f
301
+
248
302
-- | Returns the serialized OpList of all OpDefs defined in this
249
303
-- address space.
250
304
getAllOpList :: IO B. ByteString
251
- getAllOpList = do
252
- foreignPtr <-
253
- mask_ (newForeignPtr Raw. deleteBuffer =<< checkCall)
254
- -- Makes a copy because it is more reliable than eviscerating
255
- -- Buffer to steal its memory (including custom deallocator).
256
- withForeignPtr foreignPtr $
257
- \ ptr -> B. packCStringLen =<< (,)
258
- <$> (castPtr <$> Raw. getBufferData ptr)
259
- <*> (safeConvert <$> Raw. getBufferLength ptr)
305
+ getAllOpList =
306
+ bracket checkCall Raw. deleteBuffer $ \ buffer ->
307
+ -- Makes a copy because it is more reliable than eviscerating
308
+ -- Buffer to steal its memory (including custom deallocator).
309
+ B. packCStringLen =<< (,)
310
+ <$> (castPtr <$> Raw. getBufferData buffer)
311
+ <*> (safeConvert <$> Raw. getBufferLength buffer)
260
312
where
261
313
checkCall = do
262
314
p <- Raw. getAllOpList
0 commit comments