@@ -22,12 +22,10 @@
 from __future__ import division
 from __future__ import print_function
 
-import os
 from absl import flags
 import tensorflow as tf
 
 from tensorflow.contrib.tpu.python.tpu import tpu_config
-from tensorflow.contrib.training.python.training import evaluation
 
 from object_detection import model_hparams
 from object_detection import model_lib
@@ -48,31 +46,18 @@
 flags.DEFINE_string(
     'tpu_name',
     default=None,
-    help='Name of the Cloud TPU for Cluster Resolvers. You must specify either '
-    'this flag or --master.')
-
-flags.DEFINE_string(
-    'master',
-    default=None,
-    help='GRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You '
-    'must specify either this flag or --tpu_name.')
+    help='Name of the Cloud TPU for Cluster Resolvers.')
 
 flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
 flags.DEFINE_integer('iterations_per_loop', 100,
                      'Number of iterations per TPU training loop.')
 # For mode=train_and_eval, evaluation occurs after training is finished.
 # Note: independently of steps_per_checkpoint, estimator will save the most
 # recent checkpoint every 10 minutes by default for train_and_eval
-flags.DEFINE_string('mode', 'train_and_eval',
-                    'Mode to run: train, eval, train_and_eval')
+flags.DEFINE_string('mode', 'train',
                    'Mode to run: train, eval')
 flags.DEFINE_integer('train_batch_size', 32 * 8, 'Batch size for training.')
 
-# For EVAL.
-flags.DEFINE_integer('min_eval_interval_secs', 180,
-                     'Minimum seconds between evaluations.')
-flags.DEFINE_integer(
-    'eval_timeout_secs', None,
-    'Maximum seconds between checkpoints before evaluation terminates.')
 flags.DEFINE_string(
     'hparams_overrides', None, 'Comma-separated list of '
     'hyperparameters to override defaults.')
@@ -93,21 +78,12 @@ def main(unused_argv):
   flags.mark_flag_as_required('model_dir')
   flags.mark_flag_as_required('pipeline_config_path')
 
-  if FLAGS.master is None and FLAGS.tpu_name is None:
-    raise RuntimeError('You must specify either --master or --tpu_name.')
-
-  if FLAGS.master is not None:
-    if FLAGS.tpu_name is not None:
-      tf.logging.warn('Both --master and --tpu_name are set. Ignoring '
-                      '--tpu_name and using --master.')
-    tpu_grpc_url = FLAGS.master
-  else:
-    tpu_cluster_resolver = (
-        tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
-            tpu_names=[FLAGS.tpu_name],
-            zone=FLAGS.tpu_zone,
-            project=FLAGS.gcp_project))
-    tpu_grpc_url = tpu_cluster_resolver.get_master()
+  tpu_cluster_resolver = (
+      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
+          tpu_names=[FLAGS.tpu_name],
+          zone=FLAGS.tpu_zone,
+          project=FLAGS.gcp_project))
+  tpu_grpc_url = tpu_cluster_resolver.get_master()
 
   config = tpu_config.RunConfig(
       master=tpu_grpc_url,
@@ -134,53 +110,19 @@ def main(unused_argv):
   train_steps = train_and_eval_dict['train_steps']
   eval_steps = train_and_eval_dict['eval_steps']
 
-  if FLAGS.mode in ['train', 'train_and_eval']:
+  if FLAGS.mode == 'train':
     estimator.train(input_fn=train_input_fn, max_steps=train_steps)
 
-  if FLAGS.mode == 'train_and_eval':
-    # Eval one time.
-    eval_results = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
-    tf.logging.info('Eval results: %s' % eval_results)
-
   # Continuously evaluating.
   if FLAGS.mode == 'eval':
-    def terminate_eval():
-      tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
-                      FLAGS.eval_timeout_secs)
-      return True
-
-    # Run evaluation when there's a new checkpoint.
-    for ckpt in evaluation.checkpoints_iterator(
-        FLAGS.model_dir,
-        min_interval_secs=FLAGS.min_eval_interval_secs,
-        timeout=FLAGS.eval_timeout_secs,
-        timeout_fn=terminate_eval):
-
-      tf.logging.info('Starting to evaluate.')
-      if FLAGS.eval_training_data:
-        name = 'training_data'
-        input_fn = eval_on_train_input_fn
-      else:
-        name = 'validation_data'
-        input_fn = eval_input_fn
-      try:
-        eval_results = estimator.evaluate(
-            input_fn=input_fn,
-            steps=eval_steps,
-            checkpoint_path=ckpt,
-            name=name)
-        tf.logging.info('Eval results: %s' % eval_results)
-
-        # Terminate eval job when final checkpoint is reached
-        current_step = int(os.path.basename(ckpt).split('-')[1])
-        if current_step >= train_steps:
-          tf.logging.info(
-              'Evaluation finished after training step %d' % current_step)
-          break
-
-      except tf.errors.NotFoundError:
-        tf.logging.info(
-            'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
+    if FLAGS.eval_training_data:
+      name = 'training_data'
+      input_fn = eval_on_train_input_fn
+    else:
+      name = 'validation_data'
+      input_fn = eval_input_fn
+    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
+                              train_steps, name)
 
 
 if __name__ == '__main__':
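
Note: the model_lib.continuous_eval call introduced in the eval branch above replaces the inline checkpoint-polling loop that this change deletes. As a rough sketch of that polling pattern only (not the actual model_lib implementation; continuous_eval_sketch is a hypothetical name, and the body simply mirrors the deleted loop using the TF 1.x tf.contrib.training.checkpoints_iterator API):

import os

import tensorflow as tf


def continuous_eval_sketch(estimator, model_dir, input_fn, eval_steps,
                           train_steps, name):
  """Illustrative only: evaluate each new checkpoint until training ends."""
  # checkpoints_iterator blocks until a new checkpoint appears in model_dir.
  for ckpt in tf.contrib.training.checkpoints_iterator(model_dir):
    eval_results = estimator.evaluate(
        input_fn=input_fn,
        steps=eval_steps,
        checkpoint_path=ckpt,
        name=name)
    tf.logging.info('Eval results: %s' % eval_results)
    # Checkpoint filenames end in the global step (e.g. model.ckpt-10000),
    # so stop once the checkpoint from the final training step is evaluated.
    current_step = int(os.path.basename(ckpt).split('-')[1])
    if current_step >= train_steps:
      break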