Skip to content

Commit ac91d6b

Browse files
author
Bart Schuurmans
committed
Add support for loading a Session from a SavedModel
1 parent 382b9c1 commit ac91d6b

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

tensorflow/src/TensorFlow/Core.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ module TensorFlow.Core
3030
, sessionTracer
3131
, runSession
3232
, runSessionWithOptions
33+
, SavedModelTag(..)
34+
, runSavedModel
35+
, runSavedModelWithOptions
3336
-- ** Building graphs
3437
, MonadBuild(..)
3538
-- ** Running graphs

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ module TensorFlow.Internal.FFI
2020
( TensorFlowException(..)
2121
, Raw.Session
2222
, withSession
23+
, withSessionFromSavedModel
2324
, run
2425

2526
, SessionAction
@@ -121,6 +122,30 @@ withSession :: (MonadIO m, MonadMask m)
121122
-> m a
122123
withSession = withSession_ Raw.newSession
123124

125+
withSessionFromSavedModel :: (MonadIO m, MonadMask m)
126+
=> B.ByteString
127+
-- ^ exportDir
128+
-> [B.ByteString]
129+
-- ^ Tags.
130+
-> (Raw.SessionOptions -> IO ())
131+
-- ^ optionSetter
132+
-> SessionAction m a
133+
-> m a
134+
withSessionFromSavedModel exportDir tags =
135+
withSession_ $ \graph options status ->
136+
B.useAsCString exportDir $ \cExportDir ->
137+
withStringArrayLen tags $ \nTags cTags ->
138+
Raw.loadSessionFromSavedModel options
139+
runOptions
140+
cExportDir
141+
cTags (safeConvert nTags)
142+
graph
143+
metaGraphDef
144+
status
145+
where
146+
runOptions = nullPtr
147+
metaGraphDef = nullPtr
148+
124149
withSession_ :: (MonadIO m, MonadMask m)
125150
=> (Raw.Graph -> Raw.SessionOptions -> Raw.Status -> IO Raw.Session)
126151
-- ^ mkSession

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,15 @@ deleteSessionOptions = {# call TF_DeleteSessionOptions as ^ #}
189189
newSession :: Graph -> SessionOptions -> Status -> IO Session
190190
newSession = {# call TF_NewSession as ^ #}
191191

192+
loadSessionFromSavedModel :: SessionOptions
193+
-> BufferPtr -- RunOptions proto.
194+
-> CString -- Export directory.
195+
-> Ptr CString -> CInt -- Tags.
196+
-> Graph -- Graph.
197+
-> BufferPtr -- MetaGraphDef.
198+
-> Status
199+
-> IO Session
200+
loadSessionFromSavedModel = {# call TF_LoadSessionFromSavedModel as ^ #}
192201

193202
closeSession :: Session -> Status -> IO ()
194203
closeSession = {# call TF_CloseSession as ^ #}

tensorflow/src/TensorFlow/Session.hs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,16 @@ module TensorFlow.Session (
2727
sessionTracer,
2828
runSession,
2929
runSessionWithOptions,
30+
runSavedModel,
31+
runSavedModelWithOptions,
3032
MonadBuild(..),
3133
addGraphDef,
3234
run,
3335
runWithFeeds,
3436
run_,
3537
runWithFeeds_,
3638
asyncProdNodes,
39+
SavedModelTag(..),
3740
) where
3841

3942
import Data.ProtoLens.Message(defMessage)
@@ -61,6 +64,7 @@ import TensorFlow.Nodes
6164
import TensorFlow.Output (NodeName(..), unNodeName, output)
6265
import TensorFlow.Tensor
6366

67+
import qualified Data.ByteString.Char8 as C
6468
import qualified Data.ByteString.Builder as Builder
6569
import qualified Data.ByteString.Char8 as C
6670
import qualified Data.Map.Strict as Map
@@ -98,6 +102,14 @@ data Options = Options
98102
, _sessionTracer :: Tracer
99103
}
100104

105+
data SavedModelTag = GPU | TPU | Serve | Train
106+
107+
savedModelTagValue :: SavedModelTag -> ByteString
108+
savedModelTagValue GPU = "gpu"
109+
savedModelTagValue TPU = "tpu"
110+
savedModelTagValue Serve = "serve"
111+
savedModelTagValue Train = "train"
112+
101113
instance Default Options where
102114
def = Options
103115
{ _sessionTarget = ""
@@ -128,6 +140,25 @@ runSessionWithOptions :: (MonadMask m, MonadIO m) => Options -> SessionT m a ->
128140
runSessionWithOptions options session =
129141
_runSessionWithOptions session options $ FFI.withSession
130142

143+
runSavedModel :: (MonadMask m, MonadIO m)
144+
=> FilePath
145+
-- ^ Export directory.
146+
-> Set SavedModelTag
147+
-> SessionT m a
148+
-> m a
149+
runSavedModel exportDir tags = runSavedModelWithOptions exportDir tags def
150+
151+
runSavedModelWithOptions :: (MonadMask m, MonadIO m)
152+
=> FilePath
153+
-- ^ Export directory.
154+
-> Set SavedModelTag
155+
-> Options
156+
-> SessionT m a
157+
-> m a
158+
runSavedModelWithOptions exportDir tags options session =
159+
_runSessionWithOptions session options $
160+
FFI.withSessionFromSavedModel (C.pack exportDir) (map savedModelTagValue $ Set.toList tags)
161+
131162
_runSessionWithOptions :: (MonadMask m, MonadIO m)
132163
=> SessionT m a
133164
-> Options

0 commit comments

Comments
 (0)