Skip to content

Migrate from TF_DeprecatedSession to TF_Session #285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 137 additions & 64 deletions tensorflow/src/TensorFlow/Internal/FFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, extendGraph
, run

, SessionAction

, Raw.SessionOptions

, Raw.Graph
, extendGraph

, TensorData(..)
, setSessionConfig
, setSessionTarget
Expand All @@ -40,16 +47,17 @@ import Control.Monad.Catch (MonadMask, Exception, throwM, bracket, finally, mask
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Bits (Bits, toIntegralSized)
import Data.Int (Int64)
import Data.Foldable (for_)
import Data.Maybe (fromMaybe)
import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign (Ptr, FunPtr, nullPtr, castPtr, with)
import Foreign.ForeignPtr (newForeignPtr_)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as C
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
Expand Down Expand Up @@ -87,15 +95,26 @@ data TensorData = TensorData
}
deriving (Show, Eq)

-- | The action can spawn concurrent tasks which will be canceled before
-- withSession returns.
type SessionAction m a = (IO () -> IO ()) -> Raw.Session -> Raw.Graph -> m a

-- | Runs the given action after creating a session with options
-- populated by the given optionSetter.
withSession :: (MonadIO m, MonadMask m)
=> (Raw.SessionOptions -> IO ())
-> ((IO () -> IO ()) -> Raw.Session -> m a)
-- ^ The action can spawn concurrent tasks which will
-- be canceled before withSession returns.
-> SessionAction m a
-> m a
withSession optionSetter action = do
withSession = withSession_ Raw.newSession

withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSession_ mkSession optionSetter action = do
drain <- liftIO $ newMVar []
let cleanup s =
-- Closes the session to nudge the pending run calls to fail and exit.
Expand All @@ -105,11 +124,12 @@ withSession optionSetter action = do
mapM_ shutDownRunner runners
checkStatus (Raw.deleteSession s)
let bracketIO x y = bracket (liftIO x) (liftIO . y)
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracketIO
(optionSetter options >> checkStatus (Raw.newSession options))
cleanup
(action (asyncCollector drain))
bracketIO Raw.newGraph Raw.deleteGraph $ \graph ->
bracketIO Raw.newSessionOptions Raw.deleteSessionOptions $ \options -> do
bracketIO
(optionSetter options >> checkStatus (mkSession graph options))
cleanup
(\session -> action (asyncCollector drain) session graph)

asyncCollector :: MVar [Async ()] -> IO () -> IO ()
asyncCollector drain runner = modifyMVarMasked_ drain launchAndRecord
Expand All @@ -122,43 +142,103 @@ shutDownRunner r = do
-- TODO(gnezdo): manage exceptions better than print.
either print (const (return ())) =<< waitCatch r

extendGraph :: Raw.Session -> GraphDef -> IO ()
extendGraph session pb =
useProtoAsVoidPtrLen pb $ \ptr len ->
checkStatus $ Raw.extendGraph session ptr len

graphImportGraphDef :: Raw.Graph
-> GraphDef
-> (Raw.ImportGraphDefOptions -> IO ())
-> IO ()
graphImportGraphDef graph pb optionSetter =
useProtoAsBuffer pb $ \buffer ->
bracket Raw.newImportGraphDefOptions Raw.deleteImportGraphDefOptions $ \importGraphDefOptions -> do
optionSetter importGraphDefOptions
checkStatus $ Raw.graphImportGraphDef graph buffer importGraphDefOptions

forGraphOperations_ :: Raw.Graph
-> (Raw.Operation -> IO b)
-> IO ()
forGraphOperations_ graph f = with 0 go
where
go indexPtr = do
op <- Raw.graphNextOperation graph indexPtr
case op of
Raw.Operation ptr | ptr == nullPtr -> return ()
_ -> f op >> go indexPtr -- indexPtr is modified by Raw.graphNextOperation.

extendGraph :: Raw.Graph -> GraphDef -> IO ()
extendGraph graph graphDef =
graphImportGraphDef graph graphDef $ \opts ->
-- All inputs of the nodes in the GraphDef should either refer to
-- other nodes in the GraphDef, or be mapped to nodes already in
-- the Graph by adding an input mapping.
-- We add an input mapping for all existing nodes in the Graph in
-- case they are referenced in the GraphDef.
forGraphOperations_ graph $ \op -> do
srcName <- Raw.operationName op
numOutputs <- Raw.operationNumOutputs op
for_ [0..numOutputs] $ \srcIndex -> do
let dst = Raw.Output op (safeConvert srcIndex)
with dst $ Raw.importGraphDefOptionsAddInputMapping opts srcName srcIndex

run :: Raw.Session
-> [(B.ByteString, TensorData)] -- ^ Feeds.
-> [B.ByteString] -- ^ Fetches.
-> [B.ByteString] -- ^ Targets.
-> Raw.Graph
-> [(B.ByteString, TensorData)] -- ^ Inputs.
-> [B.ByteString] -- ^ Outputs.
-> [B.ByteString] -- ^ Target operations.
-> IO [TensorData]
run session feeds fetches targets = do
let nullTensor = Raw.Tensor nullPtr
run session graph inputNamesData outputNames targetNames = do
-- Use mask to avoid leaking input tensors before they are passed to 'run'
-- and output tensors before they are passed to 'createTensorData'.
mask_ $
-- Feeds
withStringArrayLen (fst <$> feeds) $ \feedsLen feedNames ->
mapM (createRawTensor . snd) feeds >>= \feedTensors ->
withArrayLen feedTensors $ \_ cFeedTensors ->
-- Fetches.
withStringArrayLen fetches $ \fetchesLen fetchNames ->
-- tensorOuts is an array of null Tensor pointers that will be filled
-- Inputs.
mapM (resolveOutput graph . fst) inputNamesData >>= \inputs ->
withArrayLen inputs $ \nInputs cInputs ->
mapM (createRawTensor . snd) inputNamesData >>= \inputTensors ->
withArrayLen inputTensors $ \_ cInputTensors ->
-- Outputs.
mapM (resolveOutput graph) outputNames >>= \outputs ->
withArrayLen outputs $ \nOutputs cOutputs ->
-- outputTensors is an array of null Tensor pointers that will be filled
-- by the call to Raw.run.
withArrayLen (replicate fetchesLen nullTensor) $ \_ tensorOuts ->
-- Targets.
withStringArrayLen targets $ \targetsLen ctargets -> do
withArrayLen (replicate nOutputs nullTensor) $ \_ cOutputTensors ->
-- Target operations.
mapM (resolveOperation graph) targetNames >>= \targets ->
withArrayLen targets $ \nTargets cTargets -> do
checkStatus $ Raw.run
session
nullPtr
feedNames cFeedTensors (safeConvert feedsLen)
fetchNames tensorOuts (safeConvert fetchesLen)
ctargets (safeConvert targetsLen)
nullPtr
mapM_ Raw.deleteTensor feedTensors
outTensors <- peekArray fetchesLen tensorOuts
nullPtr -- RunOptions proto.
cInputs cInputTensors (safeConvert nInputs)
cOutputs cOutputTensors (safeConvert nOutputs)
cTargets (safeConvert nTargets)
nullPtr -- RunMetadata.
mapM_ Raw.deleteTensor inputTensors
outTensors <- peekArray nOutputs cOutputTensors
mapM createTensorData outTensors
where

nullTensor = Raw.Tensor nullPtr

resolveOutput :: Raw.Graph -> B.ByteString -> IO Raw.Output
resolveOutput graph name = do
let (opName, idx) = parseName name
op <- resolveOperation graph opName
pure $ Raw.Output op (safeConvert idx)
where
parseName :: B.ByteString -> (B.ByteString, Int)
parseName opName =
case break (== ':') (C.unpack opName) of
(opName_, ':':idxStr) | idx <- read idxStr
-> (C.pack opName_, idx)
_ -> (opName, 0)

resolveOperation :: Raw.Graph -> B.ByteString -> IO Raw.Operation
resolveOperation graph name = do
op <- Raw.graphOperationByName graph name
case op of
Raw.Operation ptr | ptr == nullPtr -> throwM exception
_ -> pure op
where
exception =
let msg = "Operation not found in graph: " <> (T.pack $ show name)
in TensorFlowException Raw.TF_INVALID_ARGUMENT msg


-- Internal.
Expand All @@ -174,21 +254,6 @@ safeConvert x =
show (fromIntegral x :: b)))
(toIntegralSized x)


-- | Use a list of ByteString as a list of CString.
withStringList :: [B.ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = B.useAsCString x $ \c -> go xs (c:cs)


-- | Use a list of ByteString as an array of CString.
withStringArrayLen :: [B.ByteString] -> (Int -> Ptr CString -> IO a) -> IO a
withStringArrayLen xs fn = withStringList xs (`withArrayLen` fn)


-- | Create a Raw.Tensor from a TensorData.
createRawTensor :: TensorData -> IO Raw.Tensor
createRawTensor (TensorData dims dt byteVec) =
Expand Down Expand Up @@ -258,18 +323,26 @@ useProtoAsVoidPtrLen :: (Message msg, Integral c, Show c, Bits c) =>
useProtoAsVoidPtrLen msg f = B.useAsCStringLen (encodeMessage msg) $
\(bytes, len) -> f (castPtr bytes) (safeConvert len)

-- | Serializes the given msg and provides it as BufferPtr argument
-- to the given action.
useProtoAsBuffer :: (Message msg) =>
msg -> (Raw.BufferPtr -> IO a) -> IO a
useProtoAsBuffer msg f =
B.useAsCStringLen (encodeMessage msg) $ \(bytes, len) ->
bracket (Raw.newBufferFromString (castPtr bytes) (safeConvert len))
Raw.deleteBuffer
f

-- | Returns the serialized OpList of all OpDefs defined in this
-- address space.
getAllOpList :: IO B.ByteString
getAllOpList = do
foreignPtr <-
mask_ (newForeignPtr Raw.deleteBuffer =<< checkCall)
-- Makes a copy because it is more reliable than eviscerating
-- Buffer to steal its memory (including custom deallocator).
withForeignPtr foreignPtr $
\ptr -> B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData ptr)
<*> (safeConvert <$> Raw.getBufferLength ptr)
getAllOpList =
bracket checkCall Raw.deleteBuffer $ \buffer ->
-- Makes a copy because it is more reliable than eviscerating
-- Buffer to steal its memory (including custom deallocator).
B.packCStringLen =<< (,)
<$> (castPtr <$> Raw.getBufferData buffer)
<*> (safeConvert <$> Raw.getBufferLength buffer)
where
checkCall = do
p <- Raw.getAllOpList
Expand Down
Loading