Skip to content

Commit b990100

Browse files
author
Bart Schuurmans
committed
Migrate from TF_DeprecatedSession to TF_Session
1 parent cfeb1ca commit b990100

File tree

3 files changed

+201
-87
lines changed

3 files changed

+201
-87
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 98 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,11 @@
1919
module TensorFlow.Internal.FFI
2020
( TensorFlowException(..)
2121
, Raw.Session
22+
, Raw.SessionOptions
2223
, withSession
23-
, extendGraph
24+
, SessionAction
25+
, Raw.Graph
26+
, importGraphDef
2427
, run
2528
, TensorData(..)
2629
, setSessionConfig
@@ -74,15 +77,26 @@ data TensorData = TensorData
7477
}
7578
deriving (Show, Eq)
7679

80+
-- | The action can spawn concurrent tasks which will be canceled before
81+
-- withSession returns.
82+
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a
83+
7784
-- | Runs the given action after creating a session with options
7885
-- populated by the given optionSetter.
7986
withSession :: (MonadIO m, MonadMask m)
8087
=> (Raw.SessionOptions -> IO ())
81-
-> ((IO () -> IO ()) -> Raw.Session -> m a)
82-
-- ^ The action can spawn concurrent tasks which will
83-
-- be canceled before withSession returns.
88+
-> SessionAction m a
8489
-> m a
85-
withSession optionSetter action = do
90+
withSession = withSession_ Raw.newSession
91+
92+
withSession_ :: (MonadIO m, MonadMask m)
93+
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
94+
-- ^ mkSession
95+
-> (Raw.SessionOptions -> IO ())
96+
-- ^ optionSetter
97+
-> SessionAction m a
98+
-> m a
99+
withSession_ mkSession optionSetter action = do
86100
drain <- liftIO $ newMVar []
87101
let cleanup s =
88102
-- Closes the session to nudge the pending run calls to fail and exit.
@@ -92,11 +106,12 @@ withSession optionSetter action = do
92106
mapM_ shutDownRunner runners
93107
checkStatus (Raw.deleteSession s)
94108
let bracketIO x y = bracket (liftIO x) (liftIO . y)
95-
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
96-
bracketIO
97-
(optionSetter options >> checkStatus (Raw.newSession options))
98-
cleanup
99-
(action (asyncCollector drain))
109+
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
110+
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
111+
bracketIO
112+
(optionSetter options >> checkStatus (mkSession graph options))
113+
cleanup
114+
(\session -> action (asyncCollector drain) session graph)
100115

101116
asyncCollector :: MVar [Async ()] -> IO () -> IO ()
102117
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
@@ -109,43 +124,72 @@ shutDownRunner r = do
109124
-- TODO(gnezdo): manage exceptions better than print.
110125
either print (const (return ())) =<< waitCatch r
111126

112-
extendGraph :: Raw.Session -> GraphDef -> IO ()
113-
extendGraph session pb =
114-
useProtoAsVoidPtrLen pb $ \ptr len ->
115-
checkStatus $ Raw.extendGraph session ptr len
116-
127+
importGraphDef :: Raw.Graph -> GraphDef -> IO ()
128+
importGraphDef graph pb =
129+
useProtoAsBuffer pb $ \buffer ->
130+
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions ->
131+
checkStatus $ Raw.importGraphDef graph buffer importGraphDefOptions
117132

118133
run :: Raw.Session
119-
-> [(B.ByteString, TensorData)] -- ^ Feeds.
120-
-> [B.ByteString] -- ^ Fetches.
121-
-> [B.ByteString] -- ^ Targets.
134+
-> Raw.Graph
135+
-> [(B.ByteString, TensorData)] -- ^ Inputs.
136+
-> [B.ByteString] -- ^ Outputs.
137+
-> [B.ByteString] -- ^ Target operations.
122138
-> IO [TensorData]
123-
run session feeds fetches targets = do
124-
let nullTensor = Raw.Tensor nullPtr
139+
run session graph inputNamesData outputNames targetNames = do
125140
-- Use mask to avoid leaking input tensors before they are passed to 'run'
126141
-- and output tensors before they are passed to 'createTensorData'.
127142
mask_ $
128-
-- Feeds
129-
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
130-
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
131-
withArrayLen feedTensors $ \_ cFeedTensors ->
132-
-- Fetches.
133-
withStringArrayLen fetches $ \fetchesLen fetchNames ->
134-
-- tensorOuts is an array of null Tensor pointers that will be filled
143+
-- Inputs.
144+
mapM (resolveOutput . fst) inputNamesData >>= \inputs ->
145+
withArrayLen inputs $ \nInputs cInputs ->
146+
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
147+
withArrayLen inputTensors $ \_ cInputTensors ->
148+
-- Outputs.
149+
mapM resolveOutput outputNames >>= \outputs ->
150+
withArrayLen outputs $ \nOutputs cOutputs ->
151+
-- outputTensors is an array of null Tensor pointers that will be filled
135152
-- by the call to Raw.run.
136-
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
137-
-- Targets.
138-
withStringArrayLen targets $ \targetsLen ctargets -> do
153+
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
154+
-- Target operations.
155+
mapM resolveOperation targetNames >>= \targets ->
156+
withArrayLen targets $ \nTargets cTargets -> do
139157
checkStatus $ Raw.run
140158
session
141-
nullPtr
142-
feedNames cFeedTensors (safeConvert feedsLen)
143-
fetchNames tensorOuts (safeConvert fetchesLen)
144-
ctargets (safeConvert targetsLen)
145-
nullPtr
146-
mapM_ Raw.deleteTensor feedTensors
147-
outTensors <- peekArray fetchesLen tensorOuts
159+
nullPtr -- RunOptions proto.
160+
cInputs cInputTensors (safeConvert nInputs)
161+
cOutputs cOutputTensors (safeConvert nOutputs)
162+
cTargets (safeConvert nTargets)
163+
nullPtr -- RunMetadata.
164+
mapM_ Raw.deleteTensor inputTensors
165+
outTensors <- peekArray nOutputs cOutputTensors
148166
mapM createTensorData outTensors
167+
where
168+
resolveOutput :: B.ByteString -> IO Raw.Output
169+
resolveOutput name = do
170+
let (opName, idx) = parseName name
171+
op <- resolveOperation opName
172+
pure $ Raw.Output op idx
173+
174+
resolveOperation :: B.ByteString -> IO Raw.Operation
175+
resolveOperation name = do
176+
op <- B.useAsCString name $ Raw.graphOperationByName graph
177+
case op of
178+
Raw.Operation ptr | ptr == nullPtr -> throwM exception
179+
_ -> pure op
180+
where
181+
exception =
182+
let msg = "Operation not found in graph: " <> (T.pack $ show name)
183+
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg
184+
185+
parseName :: B.ByteString -> (B.ByteString, CInt)
186+
parseName name =
187+
case break (== ':') (C.unpack name) of
188+
(name, ':':idxStr) | [(idx, "" :: String)] <- read idxStr
189+
-> (C.pack name, fromInteger idx)
190+
_ -> (name, 0)
191+
192+
nullTensor = Raw.Tensor nullPtr
149193

150194

151195
-- Internal.
@@ -245,18 +289,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
245289
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
246290
\(bytes, len) -> f (castPtr bytes) (safeConvert len)
247291

292+
-- | Serializes the given msg and provides it as BufferPtr argument
293+
-- to the given action.
294+
useProtoAsBuffer :: (Message msg) =>
295+
msg -> (Raw.BufferPtr -> IO a) -> IO a
296+
useProtoAsBuffer msg f =
297+
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
298+
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
299+
Raw.deleteBuffer
300+
f
301+
248302
-- | Returns the serialized OpList of all OpDefs defined in this
249303
-- address space.
250304
getAllOpList :: IO B.ByteString
251-
getAllOpList = do
252-
foreignPtr <-
253-
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
254-
-- Makes a copy because it is more reliable than eviscerating
255-
-- Buffer to steal its memory (including custom deallocator).
256-
withForeignPtr foreignPtr $
257-
\ptr -> B.packCStringLen =<< (,)
258-
<$> (castPtr <$> Raw.getBufferData ptr)
259-
<*> (safeConvert <$> Raw.getBufferLength ptr)
305+
getAllOpList =
306+
bracket checkCall Raw.deleteBuffer $ \buffer ->
307+
-- Makes a copy because it is more reliable than eviscerating
308+
-- Buffer to steal its memory (including custom deallocator).
309+
B.packCStringLen =<< (,)
310+
<$> (castPtr <$> Raw.getBufferData buffer)
311+
<*> (safeConvert <$> Raw.getBufferLength buffer)
260312
where
261313
checkCall = do
262314
p <- Raw.getAllOpList

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,33 @@ message :: Status -> IO CString
4444
message = {# call TF_Message as ^ #}
4545

4646

47+
-- Operation.
48+
{# pointer *TF_Operation as Operation newtype #}
49+
50+
instance Storable Operation where
51+
sizeOf (Operation t) = sizeOf t
52+
alignment (Operation t) = alignment t
53+
peek p = fmap Operation (peek (castPtr p))
54+
poke p (Operation t) = poke (castPtr p) t
55+
56+
57+
-- Output.
58+
data Output = Output
59+
{ outputOperation :: Operation
60+
, outputIndex :: CInt
61+
}
62+
{# pointer *TF_Output as OutputPtr -> Output #}
63+
64+
instance Storable Output where
65+
sizeOf _ = {# sizeof TF_Output #}
66+
alignment _ = {# alignof TF_Output #}
67+
peek ptr = Output <$> {# get TF_Output->oper #} ptr
68+
<*> {# get TF_Output->index #} ptr
69+
poke ptr (Output oper index) = do
70+
{# set TF_Output->oper #} ptr oper
71+
{# set TF_Output->index #} ptr index
72+
73+
4774
-- Buffer.
4875
data Buffer
4976
{# pointer *TF_Buffer as BufferPtr -> Buffer #}
@@ -54,6 +81,12 @@ getBufferData = {# get TF_Buffer->data #}
5481
getBufferLength :: BufferPtr -> IO CULong
5582
getBufferLength = {# get TF_Buffer->length #}
5683

84+
newBufferFromString :: Ptr () -> CULong -> IO BufferPtr
85+
newBufferFromString = {# call TF_NewBufferFromString as ^ #}
86+
87+
deleteBuffer :: BufferPtr -> IO ()
88+
deleteBuffer = {# call TF_DeleteBuffer as ^ #}
89+
5790
-- Tensor.
5891
{# pointer *TF_Tensor as Tensor newtype #}
5992

@@ -97,6 +130,30 @@ tensorByteSize = {# call TF_TensorByteSize as ^ #}
97130
tensorData :: Tensor -> IO (Ptr ())
98131
tensorData = {# call TF_TensorData as ^ #}
99132

133+
-- ImportGraphDefOptions.
134+
{# pointer *TF_ImportGraphDefOptions as ImportGraphDefOptions newtype #}
135+
136+
newImportGraphDefOptions :: IO ImportGraphDefOptions
137+
newImportGraphDefOptions = {# call TF_NewImportGraphDefOptions as ^ #}
138+
139+
deleteImportGraphDefOptions :: ImportGraphDefOptions -> IO ()
140+
deleteImportGraphDefOptions = {# call TF_DeleteImportGraphDefOptions as ^ #}
141+
142+
-- Graph.
143+
{# pointer *TF_Graph as Graph newtype #}
144+
145+
newGraph :: IO Graph
146+
newGraph = {# call TF_NewGraph as ^ #}
147+
148+
deleteGraph :: Graph -> IO ()
149+
deleteGraph = {# call TF_DeleteGraph as ^ #}
150+
151+
graphOperationByName :: Graph -> CString -> IO Operation
152+
graphOperationByName = {# call TF_GraphOperationByName as ^ #}
153+
154+
importGraphDef :: Graph -> BufferPtr -> ImportGraphDefOptions -> Status -> IO ()
155+
importGraphDef = {# call TF_GraphImportGraphDef as ^ #}
156+
100157

101158
-- Session Options.
102159
{# pointer *TF_SessionOptions as SessionOptions newtype #}
@@ -115,29 +172,27 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
115172

116173

117174
-- Session.
118-
{# pointer *TF_DeprecatedSession as Session newtype #}
175+
{# pointer *TF_Session as Session newtype #}
176+
177+
newSession :: Graph -> SessionOptions -> Status -> IO Session
178+
newSession = {# call TF_NewSession as ^ #}
119179

120-
newSession :: SessionOptions -> Status -> IO Session
121-
newSession = {# call TF_NewDeprecatedSession as ^ #}
122180

123181
closeSession :: Session -> Status -> IO ()
124-
closeSession = {# call TF_CloseDeprecatedSession as ^ #}
182+
closeSession = {# call TF_CloseSession as ^ #}
125183

126184
deleteSession :: Session -> Status -> IO ()
127-
deleteSession = {# call TF_DeleteDeprecatedSession as ^ #}
128-
129-
extendGraph :: Session -> Ptr () -> CULong -> Status -> IO ()
130-
extendGraph = {# call TF_ExtendGraph as ^ #}
185+
deleteSession = {# call TF_DeleteSession as ^ #}
131186

132187
run :: Session
133-
-> BufferPtr -- RunOptions proto.
134-
-> Ptr CString -> Ptr Tensor -> CInt -- Input (names, tensors, count).
135-
-> Ptr CString -> Ptr Tensor -> CInt -- Output (names, tensors, count).
136-
-> Ptr CString -> CInt -- Target nodes (names, count).
137-
-> BufferPtr -- RunMetadata proto.
188+
-> BufferPtr -- RunOptions proto.
189+
-> OutputPtr -> Ptr Tensor -> CInt -- Input (names, tensors, count).
190+
-> OutputPtr -> Ptr Tensor -> CInt -- Output (names, tensors, count).
191+
-> Ptr Operation -> CInt -- Target operations (names, count).
192+
-> BufferPtr -- RunMetadata proto.
138193
-> Status
139194
-> IO ()
140-
run = {# call TF_Run as ^ #}
195+
run = {# call TF_SessionRun as ^ #}
141196

142197
-- FFI helpers.
143198
type TensorDeallocFn = Ptr () -> CULong -> Ptr () -> IO ()
@@ -153,6 +208,3 @@ foreign import ccall "wrapper"
153208
-- in this address space.
154209
getAllOpList :: IO BufferPtr
155210
getAllOpList = {# call TF_GetAllOpList as ^ #}
156-
157-
foreign import ccall "&TF_DeleteBuffer"
158-
deleteBuffer :: FunPtr (BufferPtr -> IO ())

0 commit comments

Comments
 (0)