Skip to content

Commit 80ce0d3

Browse files
author
Bart Schuurmans
committed
Migrate from TF_DeprecatedSession to TF_Session
Instead of calling TF_ExtendGraph, we call TF_GraphImportGraphDef and pass an input map for all existing nodes in the graph.
1 parent 30a12d7 commit 80ce0d3

File tree

4 files changed

+245
-103
lines changed

4 files changed

+245
-103
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 137 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
2020
( TensorFlowException(..)
2121
, Raw.Session
2222
, withSession
23-
, extendGraph
2423
, run
24+
25+
, SessionAction
26+
27+
, Raw.SessionOptions
28+
29+
, Raw.Graph
30+
, extendGraph
31+
2532
, TensorData(..)
2633
, setSessionConfig
2734
, setSessionTarget
@@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
4047
import Control.Monad.IO.Class (MonadIO, liftIO)
4148
import Data.Bits (Bits, toIntegralSized)
4249
import Data.Int (Int64)
50+
import Data.Foldable (for_)
4351
import Data.Maybe (fromMaybe)
4452
import Data.Typeable (Typeable)
4553
import Data.Word (Word8)
46-
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
47-
import Foreign.C.String (CString)
48-
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
54+
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
55+
import Foreign.ForeignPtr (newForeignPtr_)
4956
import Foreign.Marshal.Alloc (free)
5057
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
5158
import System.IO.Unsafe (unsafePerformIO)
5259
import qualified Data.ByteString as B
60+
import qualified Data.ByteString.Char8 as C
5361
import qualified Data.Text as T
5462
import qualified Data.Text.Encoding as T
5563
import qualified Data.Text.Encoding.Error as T
@@ -87,15 +95,26 @@ data TensorData = TensorData
8795
}
8896
deriving (Show, Eq)
8997

98+
-- | The action can spawn concurrent tasks which will be canceled before
99+
-- withSession returns.
100+
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
101+
90102
-- | Runs the given action after creating a session with options
91103
-- populated by the given optionSetter.
92104
withSession :: (MonadIO m, MonadMask m)
93105
=> (Raw.SessionOptions -> IO ())
94-
-> ((IO () -> IO ()) -> Raw.Session -> m a)
95-
-- ^ The action can spawn concurrent tasks which will
96-
-- be canceled before withSession returns.
106+
-> SessionAction m a
97107
-> m a
98-
withSession optionSetter action = do
108+
withSession = withSession_ Raw.newSession
109+
110+
withSession_ :: (MonadIO m, MonadMask m)
111+
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
112+
-- ^ mkSession
113+
-> (Raw.SessionOptions -> IO ())
114+
-- ^ optionSetter
115+
-> SessionAction m a
116+
-> m a
117+
withSession_ mkSession optionSetter action = do
99118
drain <- liftIO $ newMVar []
100119
let cleanup s =
101120
-- Closes the session to nudge the pending run calls to fail and exit.
@@ -105,11 +124,12 @@ withSession optionSetter action = do
105124
mapM_ shutDownRunner runners
106125
checkStatus (Raw.deleteSession s)
107126
let bracketIO x y = bracket (liftIO x) (liftIO . y)
108-
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
109-
bracketIO
110-
(optionSetter options >> checkStatus (Raw.newSession options))
111-
cleanup
112-
(action (asyncCollector drain))
127+
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
128+
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
129+
bracketIO
130+
(optionSetter options >> checkStatus (mkSession graph options))
131+
cleanup
132+
(\session -> action (asyncCollector drain) session graph)
113133

114134
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
115135
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
@@ -122,43 +142,103 @@ shutDownRunner r = do
122142
-- TODO(gnezdo): manage exceptions better than print.
123143
either print (const (return ())) =<< waitCatch r
124144

125-
extendGraph :: Raw.Session -> GraphDef -> IO ()
126-
extendGraph session pb =
127-
useProtoAsVoidPtrLen pb $ \ptr len ->
128-
checkStatus $ Raw.extendGraph session ptr len
129-
145+
graphImportGraphDef :: Raw.Graph
146+
-> GraphDef
147+
-> (Raw.ImportGraphDefOptions -> IO ())
148+
-> IO ()
149+
graphImportGraphDef graph pb optionSetter =
150+
useProtoAsBuffer pb $ \buffer ->
151+
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
152+
optionSetter importGraphDefOptions
153+
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions
154+
155+
forGraphOperations_ :: Raw.Graph
156+
-> (Raw.Operation -> IO b)
157+
-> IO ()
158+
forGraphOperations_ graph f = with 0 go
159+
where
160+
go indexPtr = do
161+
op <- Raw.graphNextOperation graph indexPtr
162+
case op of
163+
Raw.Operation ptr | ptr == nullPtr -> return ()
164+
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.
165+
166+
extendGraph :: Raw.Graph -> GraphDef -> IO ()
167+
extendGraph graph graphDef =
168+
graphImportGraphDef graph graphDef $ \opts ->
169+
-- All inputs of the nodes in the GraphDef should either refer to
170+
-- other nodes in the GraphDef, or be mapped to nodes already in
171+
-- the Graph by adding an input mapping.
172+
-- We add an input mapping for all existing nodes in the Graph in
173+
-- case they are referenced in the GraphDef.
174+
forGraphOperations_ graph $ \op -> do
175+
srcName <- Raw.operationName op
176+
numOutputs <- Raw.operationNumOutputs op
177+
for_ [0..numOutputs] $ \srcIndex -> do
178+
let dst = Raw.Output op (safeConvert srcIndex)
179+
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex
130180

131181
run :: Raw.Session
132-
-> [(B.ByteString, TensorData)] -- ^ Feeds.
133-
-> [B.ByteString] -- ^ Fetches.
134-
-> [B.ByteString] -- ^ Targets.
182+
-> Raw.Graph
183+
-> [(B.ByteString, TensorData)] -- ^ Inputs.
184+
-> [B.ByteString] -- ^ Outputs.
185+
-> [B.ByteString] -- ^ Target operations.
135186
-> IO [TensorData]
136-
run session feeds fetches targets = do
137-
let nullTensor = Raw.Tensor nullPtr
187+
run session graph inputNamesData outputNames targetNames = do
138188
-- Use mask to avoid leaking input tensors before they are passed to 'run'
139189
-- and output tensors before they are passed to 'createTensorData'.
140190
mask_ $
141-
-- Feeds
142-
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
143-
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
144-
withArrayLen feedTensors $ \_ cFeedTensors ->
145-
-- Fetches.
146-
withStringArrayLen fetches $ \fetchesLen fetchNames ->
147-
-- tensorOuts is an array of null Tensor pointers that will be filled
191+
-- Inputs.
192+
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
193+
withArrayLen inputs $ \nInputs cInputs ->
194+
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
195+
withArrayLen inputTensors $ \_ cInputTensors ->
196+
-- Outputs.
197+
mapM (resolveOutput graph) outputNames >>= \outputs ->
198+
withArrayLen outputs $ \nOutputs cOutputs ->
199+
-- outputTensors is an array of null Tensor pointers that will be filled
148200
-- by the call to Raw.run.
149-
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
150-
-- Targets.
151-
withStringArrayLen targets $ \targetsLen ctargets -> do
201+
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
202+
-- Target operations.
203+
mapM (resolveOperation graph) targetNames >>= \targets ->
204+
withArrayLen targets $ \nTargets cTargets -> do
152205
checkStatus $ Raw.run
153206
session
154-
nullPtr
155-
feedNames cFeedTensors (safeConvert feedsLen)
156-
fetchNames tensorOuts (safeConvert fetchesLen)
157-
ctargets (safeConvert targetsLen)
158-
nullPtr
159-
mapM_ Raw.deleteTensor feedTensors
160-
outTensors <- peekArray fetchesLen tensorOuts
207+
nullPtr -- RunOptions proto.
208+
cInputs cInputTensors (safeConvert nInputs)
209+
cOutputs cOutputTensors (safeConvert nOutputs)
210+
cTargets (safeConvert nTargets)
211+
nullPtr -- RunMetadata.
212+
mapM_ Raw.deleteTensor inputTensors
213+
outTensors <- peekArray nOutputs cOutputTensors
161214
mapM createTensorData outTensors
215+
where
216+
217+
nullTensor = Raw.Tensor nullPtr
218+
219+
resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
220+
resolveOutput graph name = do
221+
let (opName, idx) = parseName name
222+
op <- resolveOperation graph opName
223+
pure $ Raw.Output op (safeConvert idx)
224+
where
225+
parseName :: B.ByteString -> (B.ByteString, Int)
226+
parseName opName =
227+
case break (== ':') (C.unpack opName) of
228+
(opName_, ':':idxStr) | idx <- read idxStr
229+
-> (C.pack opName_, idx)
230+
_ -> (opName, 0)
231+
232+
resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
233+
resolveOperation graph name = do
234+
op <- Raw.graphOperationByName graph name
235+
case op of
236+
Raw.Operation ptr | ptr == nullPtr -> throwM exception
237+
_ -> pure op
238+
where
239+
exception =
240+
let msg = "Operation not found in graph: " <> (T.pack $ show name)
241+
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
162242

163243

164244
-- Internal.
@@ -174,21 +254,6 @@ safeConvert x =
174254
show (fromIntegral x :: b)))
175255
(toIntegralSized x)
176256

177-
178-
-- | Use a list of ByteString as a list of CString.
179-
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
180-
withStringList strings fn = go strings []
181-
where
182-
go [] cs = fn (reverse cs)
183-
-- TODO(fmayle): Is it worth using unsafeAsCString here?
184-
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)
185-
186-
187-
-- | Use a list of ByteString as an array of CString.
188-
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
189-
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)
190-
191-
192257
-- | Create a Raw.Tensor from a TensorData.
193258
createRawTensor :: TensorData -> IO Raw.Tensor
194259
createRawTensor (TensorData dims dt byteVec) =
@@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
258323
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
259324
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
260325

326+
-- | Serializes the given msg and provides it as BufferPtr argument
327+
-- to the given action.
328+
useProtoAsBuffer :: (Message msg) =>
329+
msg -> (Raw.BufferPtr -> IO a) -> IO a
330+
useProtoAsBuffer msg f =
331+
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
332+
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
333+
Raw.deleteBuffer
334+
f
335+
261336
-- | Returns the serialized OpList of all OpDefs defined in this
262337
-- address space.
263338
getAllOpList :: IO B.ByteString
264-
getAllOpList = do
265-
foreignPtr <-
266-
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
267-
-- Makes a copy because it is more reliable than eviscerating
268-
-- Buffer to steal its memory (including custom deallocator).
269-
withForeignPtr foreignPtr $
270-
\ptr -> B.packCStringLen =<< (,)
271-
<$> (castPtr <$> Raw.getBufferData ptr)
272-
<*> (safeConvert <$> Raw.getBufferLength ptr)
339+
getAllOpList =
340+
bracket checkCall Raw.deleteBuffer $ \buffer ->
341+
-- Makes a copy because it is more reliable than eviscerating
342+
-- Buffer to steal its memory (including custom deallocator).
343+
B.packCStringLen =<< (,)
344+
<$> (castPtr <$> Raw.getBufferData buffer)
345+
<*> (safeConvert <$> Raw.getBufferLength buffer)
273346
where
274347
checkCall = do
275348
p <- Raw.getAllOpList

0 commit comments

Comments
 (0)