-
Notifications
You must be signed in to change notification settings - Fork 0
/
export_inference_graph.py
60 lines (54 loc) · 2.53 KB
/
export_inference_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import tensorflow as tf
from google.protobuf import text_format
import sys
sys.path.append("models/research/")
sys.path.append("models/research/slim/")
from object_detection import exporter
from object_detection.protos import pipeline_pb2
slim = tf.contrib.slim
flags = tf.app.flags
flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
'one of [`image_tensor`, `encoded_image_string_tensor`, '
'`tf_example`]')
flags.DEFINE_string('input_shape', None,
'If input_type is `image_tensor`, this can explicitly set '
'the shape of this input tensor to a fixed size. The '
'dimensions are to be provided as a comma-separated list '
'of integers. A value of -1 can be used for unknown '
'dimensions. If not specified, for an `image_tensor, the '
'default shape will be partially specified as '
'`[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', None,
'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
'file.')
flags.DEFINE_string('trained_checkpoint_prefix', None,
'Path to trained checkpoint, typically of the form '
'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
'pipeline_pb2.TrainEvalPipelineConfig '
'text proto to override pipeline_config_path.')
flags.DEFINE_boolean('write_inference_graph', False,
'If true, writes inference graph to disk.')
tf.app.flags.mark_flag_as_required('pipeline_config_path')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS
def main(_):
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
text_format.Merge(f.read(), pipeline_config)
text_format.Merge(FLAGS.config_override, pipeline_config)
if FLAGS.input_shape:
input_shape = [
int(dim) if dim != '-1' else None
for dim in FLAGS.input_shape.split(',')
]
else:
input_shape = None
exporter.export_inference_graph(
FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix,
FLAGS.output_directory, input_shape=input_shape,
write_inference_graph=FLAGS.write_inference_graph)
if __name__ == '__main__':
tf.app.run()