@@ -27,20 +27,7 @@ module TensorFlow.Internal.FFI
27
27
, Raw. SessionOptions
28
28
29
29
, Raw. Graph
30
- , Raw. graphOperationByName
31
- , resolveOutput
32
- , resolveOperation
33
- , graphImportGraphDef
34
- , forGraphOperations_
35
-
36
- , Raw. Operation
37
- , Raw. operationName
38
- , Raw. operationNumOutputs
39
-
40
- , Raw. ImportGraphDefOptions
41
- , Raw. importGraphDefOptionsAddInputMapping
42
-
43
- , Raw. Output (.. )
30
+ , extendGraph
44
31
45
32
, TensorData (.. )
46
33
, setSessionConfig
@@ -60,6 +47,7 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
60
47
import Control.Monad.IO.Class (MonadIO , liftIO )
61
48
import Data.Bits (Bits , toIntegralSized )
62
49
import Data.Int (Int64 )
50
+ import Data.Foldable (for_ )
63
51
import Data.Maybe (fromMaybe )
64
52
import Data.Typeable (Typeable )
65
53
import Data.Word (Word8 )
@@ -174,6 +162,21 @@ forGraphOperations_ graph f = with 0 go
174
162
Raw. Operation ptr | ptr == nullPtr -> return ()
175
163
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
176
164
165
+ extendGraph :: Raw. Graph -> GraphDef -> IO ()
166
+ extendGraph graph graphDef =
167
+ graphImportGraphDef graph graphDef $ \ opts ->
168
+ -- All inputs of the nodes in the GraphDef should either refer to
169
+ -- other nodes in the GraphDef, or be mapped to nodes already in
170
+ -- the Graph by adding an input mapping.
171
+ -- We add an input mapping for all existing nodes in the Graph in
172
+ -- case they are referenced in the GraphDef.
173
+ forGraphOperations_ graph $ \ op -> do
174
+ srcName <- Raw. operationName op
175
+ numOutputs <- Raw. operationNumOutputs op
176
+ for_ [0 .. numOutputs] $ \ srcIndex -> do
177
+ let dst = Raw. Output op (safeConvert srcIndex)
178
+ with dst $ Raw. importGraphDefOptionsAddInputMapping opts srcName srcIndex
179
+
177
180
run :: Raw. Session
178
181
-> Raw. Graph
179
182
-> [(String , TensorData )] -- ^ Inputs.
0 commit comments