1
- # TODO: Stop using v1 compatibility
2
- import tensorflow .compat .v1 as tf
1
+ import tensorflow as tf
3
2
3
+ # check tensorflow version is 2.x
4
+ tf_major_version = tf .__version__ .split ('.' )[0 ]
5
+ assert tf_major_version == '2'
4
6
5
- tf .disable_eager_execution ()
6
- x = tf .placeholder (tf .int32 , name = 'x' )
7
- y = tf .placeholder (tf .int32 , name = 'y' )
8
- z = tf .add (x , y , name = 'z' )
7
+ @tf .function
8
+ def add (x , y ):
9
+ return tf .add (x , y )
9
10
10
- tf .variables_initializer (tf .global_variables (), name = 'init' )
11
+ x = tf .TensorSpec ((), dtype = tf .dtypes .int32 , name = 'x' )
12
+ y = tf .TensorSpec ((), dtype = tf .dtypes .int32 , name = 'y' )
11
13
12
- definition = tf . Session (). graph_def
14
+ concrete_function = add . get_concrete_function ( x , y )
13
15
directory = 'examples/addition'
14
- tf .train .write_graph (definition , directory , 'model.pb' , as_text = False )
16
+ tf .io .write_graph (concrete_function .graph , directory , 'model.pb' , as_text = False )
17
+
18
+ # check inputs/outputs node names to refer from Rust later on
19
+ print (f'input nodes : { concrete_function .inputs } ' )
20
+ print (f'output nodes : { concrete_function .outputs } ' )
0 commit comments