Skip to content

Commit 12ac796

Browse files
author
Bart Schuurmans
committed
Add support for loading a Session from a SavedModel
1 parent 8ec68b1 commit 12ac796

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

tensorflow/src/TensorFlow/Core.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ module TensorFlow.Core
3030
, sessionTracer
3131
, runSession
3232
, runSessionWithOptions
33+
, SavedModelTag(..)
34+
, runSavedModel
35+
, runSavedModelWithOptions
3336
-- ** Building graphs
3437
, MonadBuild(..)
3538
-- ** Running graphs

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ module TensorFlow.Internal.FFI
2121
, Raw.Session
2222
, Raw.SessionOptions
2323
, withSession
24+
, withSessionFromSavedModel
2425
, SessionAction
2526
, Raw.Graph
2627
, importGraphDef
@@ -45,12 +46,14 @@ import Data.Maybe (fromMaybe)
4546
import Data.Typeable (Typeable)
4647
import Data.Word (Word8)
4748
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
49+
import Foreign.C (CInt)
4850
import Foreign.C.String (CString)
4951
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
5052
import Foreign.Marshal.Alloc (free)
5153
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
5254
import System.IO.Unsafe (unsafePerformIO)
5355
import qualified Data.ByteString as B
56+
import qualified Data.ByteString.Char8 as C
5457
import qualified Data.Text as T
5558
import qualified Data.Text.Encoding as T
5659
import qualified Data.Text.Encoding.Error as T
@@ -89,6 +92,30 @@ withSession :: (MonadIO m, MonadMask m)
8992
-> m a
9093
withSession = withSession_ Raw.newSession
9194

95+
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
96+
=> B.ByteString
97+
-- ^ exportDir
98+
-> [B.ByteString]
99+
-- ^ Tags.
100+
-> (Raw.SessionOptions -> IO ())
101+
-- ^ optionSetter
102+
-> SessionAction m a
103+
-> m a
104+
withSessionFromSavedModel exportDir tags =
105+
withSession_ $ \graph options status ->
106+
B.useAsCString exportDir $ \cExportDir ->
107+
withStringArrayLen tags $ \nTags cTags ->
108+
Raw.loadSessionFromSavedModel options
109+
runOptions
110+
cExportDir
111+
cTags (safeConvert nTags)
112+
graph
113+
metaGraphDef
114+
status
115+
where
116+
runOptions = nullPtr
117+
metaGraphDef = nullPtr
118+
92119
withSession_ :: (MonadIO m, MonadMask m)
93120
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
94121
-- ^ mkSession

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,15 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
177177
newSession :: Graph -> SessionOptions -> Status -> IO Session
178178
newSession = {# call TF_NewSession as ^ #}
179179

180+
loadSessionFromSavedModel :: SessionOptions
181+
-> BufferPtr -- RunOptions proto.
182+
-> CString -- Export directory.
183+
-> Ptr CString -> CInt -- Tags.
184+
-> Graph -- Graph.
185+
-> BufferPtr -- MetaGraphDef.
186+
-> Status
187+
-> IO Session
188+
loadSessionFromSavedModel = {# call TF_LoadSessionFromSavedModel as ^ #}
180189

181190
closeSession :: Session -> Status -> IO ()
182191
closeSession = {# call TF_CloseSession as ^ #}

tensorflow/src/TensorFlow/Session.hs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ module TensorFlow.Session (
2727
sessionTracer,
2828
runSession,
2929
runSessionWithOptions,
30+
runSavedModel,
31+
runSavedModelWithOptions,
3032
MonadBuild(..),
3133
addGraphDef,
3234
run,
3335
runWithFeeds,
3436
run_,
3537
runWithFeeds_,
3638
asyncProdNodes,
39+
SavedModelTag(..),
3740
) where
3841

3942
import Data.ProtoLens.Message(defMessage)
@@ -57,6 +60,7 @@ import TensorFlow.Nodes
5760
import TensorFlow.Output (NodeName, unNodeName)
5861
import TensorFlow.Tensor
5962

63+
import qualified Data.ByteString.Char8 as C
6064
import qualified Data.ByteString.Builder as Builder
6165
import qualified Data.Map.Strict as Map
6266
import qualified Data.Set as Set
@@ -92,6 +96,14 @@ data Options = Options
9296
, _sessionTracer :: Tracer
9397
}
9498

99+
data SavedModelTag = GPU | TPU | Serve | Train
100+
101+
savedModelTagValue :: SavedModelTag -> ByteString
102+
savedModelTagValue GPU = "gpu"
103+
savedModelTagValue TPU = "tpu"
104+
savedModelTagValue Serve = "serve"
105+
savedModelTagValue Train = "train"
106+
95107
instance Default Options where
96108
def = Options
97109
{ _sessionTarget = ""
@@ -122,6 +134,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
122134
runSessionWithOptions options session =
123135
_runSessionWithOptions session options $ FFI.withSession
124136

137+
runSavedModel :: (MonadMask m, MonadIO m)
138+
=> FilePath
139+
-- ^ Export directory.
140+
-> Set SavedModelTag
141+
-> SessionT m a
142+
-> m a
143+
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
144+
145+
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
146+
=> FilePath
147+
-- ^ Export directory.
148+
-> Set SavedModelTag
149+
-> Options
150+
-> SessionT m a
151+
-> m a
152+
runSavedModelWithOptions exportDir tags options session =
153+
_runSessionWithOptions session options $
154+
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
155+
125156
_runSessionWithOptions :: (MonadMask m, MonadIO m)
126157
=> SessionT m a
127158
-> Options

0 commit comments

Comments
 (0)