Skip to content

Support loading a Session from a SavedModel #286

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorflow/src/TensorFlow/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ module TensorFlow.Core
, sessionTracer
, runSession
, runSessionWithOptions
, SavedModelTag(..)
, runSavedModel
, runSavedModelWithOptions
-- ** Building graphs
, MonadBuild(..)
-- ** Running graphs
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/src/TensorFlow/Internal/FFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
( TensorFlowException(..)
, Raw.Session
, withSession
, withSessionFromSavedModel
, run

, SessionAction
Expand Down Expand Up @@ -107,6 +108,28 @@ withSession :: (MonadIO m, MonadMask m)
-> m a
withSession = withSession_ Raw.newSession

withSessionFromSavedModel :: (MonadIO m, MonadMask m)
=> B.ByteString
-- ^ exportDir
-> [B.ByteString]
-- ^ Tags.
-> (Raw.SessionOptions -> IO ())
-- ^ optionSetter
-> SessionAction m a
-> m a
withSessionFromSavedModel exportDir tags =
withSession_ $ \graph options status ->
Raw.loadSessionFromSavedModel options
runOptions
exportDir
tags
graph
metaGraphDef
status
where
runOptions = nullPtr
metaGraphDef = nullPtr

withSession_ :: (MonadIO m, MonadMask m)
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
-- ^ mkSession
Expand Down
25 changes: 25 additions & 0 deletions tensorflow/src/TensorFlow/Internal/Raw.chs
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,16 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
newSession :: Graph -> SessionOptions -> Status -> IO Session
newSession = {# call TF_NewSession as ^ #}

{# fun TF_LoadSessionFromSavedModel as loadSessionFromSavedModel
{ `SessionOptions'
, `BufferPtr' -- RunOptions proto.
, useAsCString* `ByteString' -- Export directory.
, withStringArrayLen* `[ByteString]'& -- Tags.
, `Graph'
, `BufferPtr' -- MetaGraphDef.
, `Status'
} -> `Session'
#}

closeSession :: Session -> Status -> IO ()
closeSession = {# call TF_CloseSession as ^ #}
Expand Down Expand Up @@ -231,3 +241,18 @@ foreign import ccall "wrapper"
-- in this address space.
getAllOpList :: IO BufferPtr
getAllOpList = {# call TF_GetAllOpList as ^ #}

-- | Use a list of ByteString as a list of CString.
withStringList :: [ByteString] -> ([CString] -> IO a) -> IO a
withStringList strings fn = go strings []
where
go [] cs = fn (reverse cs)
-- TODO(fmayle): Is it worth using unsafeAsCString here?
go (x:xs) cs = useAsCString x $ \c -> go xs (c:cs)


-- | Use a list of ByteString as an array of CString with its length.
withStringArrayLen :: [ByteString] -> ((Ptr CString, CInt) -> IO a) -> IO a
withStringArrayLen xs fn =
withStringList xs $ \strings ->
withArrayLen strings $ \len ptr -> fn (ptr, fromIntegral len)
31 changes: 31 additions & 0 deletions tensorflow/src/TensorFlow/Session.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ module TensorFlow.Session (
sessionTracer,
runSession,
runSessionWithOptions,
runSavedModel,
runSavedModelWithOptions,
MonadBuild(..),
extend,
addGraphDef,
Expand All @@ -35,6 +37,7 @@ module TensorFlow.Session (
run_,
runWithFeeds_,
asyncProdNodes,
SavedModelTag(..),
) where

import Data.ProtoLens.Message(defMessage)
Expand All @@ -58,6 +61,7 @@ import TensorFlow.Nodes
import TensorFlow.Output (NodeName(..), unNodeName)
import TensorFlow.Tensor

import qualified Data.ByteString.Char8 as C
import qualified Data.ByteString.Builder as Builder
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
Expand Down Expand Up @@ -97,6 +101,14 @@ data Options = Options
, _sessionTracer :: Tracer
}

data SavedModelTag = GPU | TPU | Serve | Train

savedModelTagValue :: SavedModelTag -> ByteString
savedModelTagValue GPU = "gpu"
savedModelTagValue TPU = "tpu"
savedModelTagValue Serve = "serve"
savedModelTagValue Train = "train"

instance Default Options where
def = Options
{ _sessionTarget = ""
Expand All @@ -123,6 +135,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
runSessionWithOptions options session =
_runSessionWithOptions session options $ FFI.withSession

runSavedModel :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> SessionT m a
-> m a
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def

runSavedModelWithOptions :: (MonadMask m, MonadIO m)
=> FilePath
-- ^ Export directory.
-> Set SavedModelTag
-> Options
-> SessionT m a
-> m a
runSavedModelWithOptions exportDir tags options session =
_runSessionWithOptions session options $
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)

_runSessionWithOptions :: (MonadMask m, MonadIO m)
=> SessionT m a
-> Options
Expand Down