19
19
module TensorFlow.Internal.FFI
20
20
( TensorFlowException (.. )
21
21
, Raw. Session
22
- , Raw. SessionOptions
23
22
, withSession
23
+ , run
24
+
24
25
, SessionAction
26
+
27
+ , Raw. SessionOptions
28
+
25
29
, Raw. Graph
26
- , importGraphDef
27
- , run
30
+ , Raw. graphOperationByName
31
+ , resolveOutput
32
+ , resolveOperation
33
+ , graphImportGraphDef
34
+ , forGraphOperations_
35
+
36
+ , Raw. Operation
37
+ , Raw. operationName
38
+ , Raw. operationNumOutputs
39
+
40
+ , Raw. ImportGraphDefOptions
41
+ , Raw. importGraphDefOptionsAddInputMapping
42
+
43
+ , Raw. Output (.. )
44
+
28
45
, TensorData (.. )
29
46
, setSessionConfig
30
47
, setSessionTarget
@@ -46,9 +63,8 @@ import Data.Int (Int64)
46
63
import Data.Maybe (fromMaybe )
47
64
import Data.Typeable (Typeable )
48
65
import Data.Word (Word8 )
49
- import Foreign (Ptr , FunPtr , nullPtr , castPtr )
50
- import Foreign.C.String (CString )
51
- import Foreign.ForeignPtr (newForeignPtr , newForeignPtr_ , withForeignPtr )
66
+ import Foreign (Ptr , FunPtr , nullPtr , castPtr , with )
67
+ import Foreign.ForeignPtr (newForeignPtr_ )
52
68
import Foreign.Marshal.Alloc (free )
53
69
import Foreign.Marshal.Array (withArrayLen , peekArray , mallocArray , copyArray )
54
70
import System.IO.Unsafe (unsafePerformIO )
@@ -137,35 +153,50 @@ shutDownRunner r = do
137
153
-- TODO(gnezdo): manage exceptions better than print.
138
154
either print (const (return () )) =<< waitCatch r
139
155
140
- importGraphDef :: Raw. Graph -> GraphDef -> IO ()
141
- importGraphDef graph pb =
156
+ graphImportGraphDef :: Raw. Graph
157
+ -> GraphDef
158
+ -> (Raw. ImportGraphDefOptions -> IO () )
159
+ -> IO ()
160
+ graphImportGraphDef graph pb optionSetter =
142
161
useProtoAsBuffer pb $ \ buffer ->
143
- bracket Raw. newImportGraphDefOptions Raw. deleteImportGraphDefOptions $ \ importGraphDefOptions ->
144
- checkStatus $ Raw. importGraphDef graph buffer importGraphDefOptions
162
+ bracket Raw. newImportGraphDefOptions Raw. deleteImportGraphDefOptions $ \ importGraphDefOptions -> do
163
+ optionSetter importGraphDefOptions
164
+ checkStatus $ Raw. graphImportGraphDef graph buffer importGraphDefOptions
165
+
166
+ forGraphOperations_ :: Raw. Graph
167
+ -> (Raw. Operation -> IO b )
168
+ -> IO ()
169
+ forGraphOperations_ graph f = with 0 go
170
+ where
171
+ go indexPtr = do
172
+ op <- Raw. graphNextOperation graph indexPtr
173
+ case op of
174
+ Raw. Operation ptr | ptr == nullPtr -> return ()
175
+ _ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
145
176
146
177
run :: Raw. Session
147
178
-> Raw. Graph
148
- -> [(B. ByteString , TensorData )] -- ^ Inputs.
149
- -> [B. ByteString ] -- ^ Outputs.
150
- -> [B. ByteString ] -- ^ Target operations.
179
+ -> [(String , TensorData )] -- ^ Inputs.
180
+ -> [String ] -- ^ Outputs.
181
+ -> [String ] -- ^ Target operations.
151
182
-> IO [TensorData ]
152
183
run session graph inputNamesData outputNames targetNames = do
153
184
-- Use mask to avoid leaking input tensors before they are passed to 'run'
154
185
-- and output tensors before they are passed to 'createTensorData'.
155
186
mask_ $
156
187
-- Inputs.
157
- mapM (resolveOutput . fst ) inputNamesData >>= \ inputs ->
188
+ mapM (resolveOutput graph . fst ) inputNamesData >>= \ inputs ->
158
189
withArrayLen inputs $ \ nInputs cInputs ->
159
190
mapM (createRawTensor . snd ) inputNamesData >>= \ inputTensors ->
160
191
withArrayLen inputTensors $ \ _ cInputTensors ->
161
192
-- Outputs.
162
- mapM resolveOutput outputNames >>= \ outputs ->
193
+ mapM ( resolveOutput graph) outputNames >>= \ outputs ->
163
194
withArrayLen outputs $ \ nOutputs cOutputs ->
164
195
-- outputTensors is an array of null Tensor pointers that will be filled
165
196
-- by the call to Raw.run.
166
197
withArrayLen (replicate nOutputs nullTensor) $ \ _ cOutputTensors ->
167
198
-- Target operations.
168
- mapM resolveOperation targetNames >>= \ targets ->
199
+ mapM ( resolveOperation graph) targetNames >>= \ targets ->
169
200
withArrayLen targets $ \ nTargets cTargets -> do
170
201
checkStatus $ Raw. run
171
202
session
@@ -178,32 +209,33 @@ run session graph inputNamesData outputNames targetNames = do
178
209
outTensors <- peekArray nOutputs cOutputTensors
179
210
mapM createTensorData outTensors
180
211
where
181
- resolveOutput :: B. ByteString -> IO Raw. Output
182
- resolveOutput name = do
183
- let (opName, idx) = parseName name
184
- op <- resolveOperation opName
185
- pure $ Raw. Output op idx
186
-
187
- resolveOperation :: B. ByteString -> IO Raw. Operation
188
- resolveOperation name = do
189
- op <- B. useAsCString name $ Raw. graphOperationByName graph
190
- case op of
191
- Raw. Operation ptr | ptr == nullPtr -> throwM exception
192
- _ -> pure op
193
- where
194
- exception =
195
- let msg = " Operation not found in graph: " <> (T. pack $ show name)
196
- in TensorFlowException Raw. TF_INVALID_ARGUMENT msg
197
-
198
- parseName :: B. ByteString -> (B. ByteString , CInt )
199
- parseName name =
200
- case break (== ' :' ) (C. unpack name) of
201
- (name, ' :' : idxStr) | idx <- read idxStr
202
- -> (C. pack name, idx)
203
- _ -> (name, 0 )
204
212
205
213
nullTensor = Raw. Tensor nullPtr
206
214
215
+ resolveOutput :: Raw. Graph -> String -> IO Raw. Output
216
+ resolveOutput graph name = do
217
+ let (opName, idx) = parseName name
218
+ op <- resolveOperation graph opName
219
+ pure $ Raw. Output op (safeConvert idx)
220
+ where
221
+ parseName :: String -> (String , Int )
222
+ parseName opName =
223
+ case break (== ' :' ) opName of
224
+ (opName_, ' :' : idxStr) | idx <- read idxStr
225
+ -> (opName_, idx)
226
+ _ -> (opName, 0 )
227
+
228
+ resolveOperation :: Raw. Graph -> String -> IO Raw. Operation
229
+ resolveOperation graph name = do
230
+ op <- Raw. graphOperationByName graph name
231
+ case op of
232
+ Raw. Operation ptr | ptr == nullPtr -> throwM exception
233
+ _ -> pure op
234
+ where
235
+ exception =
236
+ let msg = " Operation not found in graph: " <> (T. pack $ show name)
237
+ in TensorFlowException Raw. TF_INVALID_ARGUMENT msg
238
+
207
239
208
240
-- Internal.
209
241
@@ -218,21 +250,6 @@ safeConvert x =
218
250
show (fromIntegral x :: b )))
219
251
(toIntegralSized x)
220
252
221
-
222
- -- | Use a list of ByteString as a list of CString.
223
- withStringList :: [B. ByteString ] -> ([CString ] -> IO a ) -> IO a
224
- withStringList strings fn = go strings []
225
- where
226
- go [] cs = fn (reverse cs)
227
- -- TODO(fmayle): Is it worth using unsafeAsCString here?
228
- go (x: xs) cs = B. useAsCString x $ \ c -> go xs (c: cs)
229
-
230
-
231
- -- | Use a list of ByteString as an array of CString.
232
- withStringArrayLen :: [B. ByteString ] -> (Int -> Ptr CString -> IO a ) -> IO a
233
- withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
234
-
235
-
236
253
-- | Create a Raw.Tensor from a TensorData.
237
254
createRawTensor :: TensorData -> IO Raw. Tensor
238
255
createRawTensor (TensorData dims dt byteVec) =
0 commit comments