Commit a60dd98

Refactor model_tpu_main.py files and move continuous eval loop into model_lib.py
PiperOrigin-RevId: 192512429
1 parent f98f000 commit a60dd98

2 files changed: 61 additions, 76 deletions


research/object_detection/model_lib.py

Lines changed: 43 additions & 0 deletions
@@ -19,6 +19,7 @@
 from __future__ import print_function

 import functools
+import os

 import tensorflow as tf

@@ -574,6 +575,48 @@ def create_train_and_eval_specs(train_input_fn,
   return train_spec, eval_specs


+def continuous_eval(estimator, model_dir, input_fn, eval_steps, train_steps,
+                    name):
+  """Perform continuous evaluation on checkpoints written to a model directory.
+
+  Args:
+    estimator: Estimator object to use for evaluation.
+    model_dir: Model directory to read checkpoints for continuous evaluation.
+    input_fn: Input function to use for evaluation.
+    eval_steps: Number of steps to run during each evaluation.
+    train_steps: Number of training steps. This is used to infer the last
+      checkpoint and stop evaluation loop.
+    name: Namescope for eval summary.
+  """
+  def terminate_eval():
+    tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
+    return True
+
+  for ckpt in tf.contrib.training.checkpoints_iterator(
+      model_dir, min_interval_secs=180, timeout=None,
+      timeout_fn=terminate_eval):
+
+    tf.logging.info('Starting Evaluation.')
+    try:
+      eval_results = estimator.evaluate(
+          input_fn=input_fn,
+          steps=eval_steps,
+          checkpoint_path=ckpt,
+          name=name)
+      tf.logging.info('Eval results: %s' % eval_results)
+
+      # Terminate eval job when final checkpoint is reached
+      current_step = int(os.path.basename(ckpt).split('-')[1])
+      if current_step >= train_steps:
+        tf.logging.info(
+            'Evaluation finished after training step %d' % current_step)
+        break
+
+    except tf.errors.NotFoundError:
+      tf.logging.info(
+          'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
+
+
 def populate_experiment(run_config,
                         hparams,
                         pipeline_config_path,
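
The exit condition of the new continuous_eval helper relies on the Estimator checkpoint naming convention (model.ckpt-<global_step>). A small stand-alone illustration of that parsing, using a hypothetical checkpoint path that is not part of this commit:

import os

# Hypothetical checkpoint path as an Estimator would write it into model_dir.
ckpt = '/tmp/model_dir/model.ckpt-20000'

# Same parsing as in continuous_eval: the global step is the suffix after '-'.
current_step = int(os.path.basename(ckpt).split('-')[1])
print(current_step)  # 20000

# Once current_step >= train_steps, the evaluation loop breaks.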

research/object_detection/model_tpu_main.py

Lines changed: 18 additions & 76 deletions
@@ -22,12 +22,10 @@
 from __future__ import division
 from __future__ import print_function

-import os
 from absl import flags
 import tensorflow as tf

 from tensorflow.contrib.tpu.python.tpu import tpu_config
-from tensorflow.contrib.training.python.training import evaluation

 from object_detection import model_hparams
 from object_detection import model_lib
@@ -48,31 +46,18 @@
 flags.DEFINE_string(
     'tpu_name',
     default=None,
-    help='Name of the Cloud TPU for Cluster Resolvers. You must specify either '
-    'this flag or --master.')
-
-flags.DEFINE_string(
-    'master',
-    default=None,
-    help='GRPC URL of the master (e.g. grpc://ip.address.of.tpu:8470). You '
-    'must specify either this flag or --tpu_name.')
+    help='Name of the Cloud TPU for Cluster Resolvers.')

 flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
 flags.DEFINE_integer('iterations_per_loop', 100,
                      'Number of iterations per TPU training loop.')
 # For mode=train_and_eval, evaluation occurs after training is finished.
 # Note: independently of steps_per_checkpoint, estimator will save the most
 # recent checkpoint every 10 minutes by default for train_and_eval
-flags.DEFINE_string('mode', 'train_and_eval',
-                    'Mode to run: train, eval, train_and_eval')
+flags.DEFINE_string('mode', 'train',
+                    'Mode to run: train, eval')
 flags.DEFINE_integer('train_batch_size', 32 * 8, 'Batch size for training.')

-# For EVAL.
-flags.DEFINE_integer('min_eval_interval_secs', 180,
-                     'Minimum seconds between evaluations.')
-flags.DEFINE_integer(
-    'eval_timeout_secs', None,
-    'Maximum seconds between checkpoints before evaluation terminates.')
 flags.DEFINE_string(
     'hparams_overrides', None, 'Comma-separated list of '
     'hyperparameters to override defaults.')
@@ -93,21 +78,12 @@ def main(unused_argv):
   flags.mark_flag_as_required('model_dir')
   flags.mark_flag_as_required('pipeline_config_path')

-  if FLAGS.master is None and FLAGS.tpu_name is None:
-    raise RuntimeError('You must specify either --master or --tpu_name.')
-
-  if FLAGS.master is not None:
-    if FLAGS.tpu_name is not None:
-      tf.logging.warn('Both --master and --tpu_name are set. Ignoring '
-                      '--tpu_name and using --master.')
-    tpu_grpc_url = FLAGS.master
-  else:
-    tpu_cluster_resolver = (
-        tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
-            tpu_names=[FLAGS.tpu_name],
-            zone=FLAGS.tpu_zone,
-            project=FLAGS.gcp_project))
-    tpu_grpc_url = tpu_cluster_resolver.get_master()
+  tpu_cluster_resolver = (
+      tf.contrib.cluster_resolver.python.training.TPUClusterResolver(
+          tpu_names=[FLAGS.tpu_name],
+          zone=FLAGS.tpu_zone,
+          project=FLAGS.gcp_project))
+  tpu_grpc_url = tpu_cluster_resolver.get_master()

   config = tpu_config.RunConfig(
       master=tpu_grpc_url,
@@ -134,53 +110,19 @@ def main(unused_argv):
   train_steps = train_and_eval_dict['train_steps']
   eval_steps = train_and_eval_dict['eval_steps']

-  if FLAGS.mode in ['train', 'train_and_eval']:
+  if FLAGS.mode == 'train':
     estimator.train(input_fn=train_input_fn, max_steps=train_steps)

-  if FLAGS.mode == 'train_and_eval':
-    # Eval one time.
-    eval_results = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
-    tf.logging.info('Eval results: %s' % eval_results)
-
   # Continuously evaluating.
   if FLAGS.mode == 'eval':
-    def terminate_eval():
-      tf.logging.info('Terminating eval after %d seconds of no checkpoints' %
-                      FLAGS.eval_timeout_secs)
-      return True
-
-    # Run evaluation when there's a new checkpoint.
-    for ckpt in evaluation.checkpoints_iterator(
-        FLAGS.model_dir,
-        min_interval_secs=FLAGS.min_eval_interval_secs,
-        timeout=FLAGS.eval_timeout_secs,
-        timeout_fn=terminate_eval):
-
-      tf.logging.info('Starting to evaluate.')
-      if FLAGS.eval_training_data:
-        name = 'training_data'
-        input_fn = eval_on_train_input_fn
-      else:
-        name = 'validation_data'
-        input_fn = eval_input_fn
-      try:
-        eval_results = estimator.evaluate(
-            input_fn=input_fn,
-            steps=eval_steps,
-            checkpoint_path=ckpt,
-            name=name)
-        tf.logging.info('Eval results: %s' % eval_results)
-
-        # Terminate eval job when final checkpoint is reached
-        current_step = int(os.path.basename(ckpt).split('-')[1])
-        if current_step >= train_steps:
-          tf.logging.info(
-              'Evaluation finished after training step %d' % current_step)
-          break
-
-      except tf.errors.NotFoundError:
-        tf.logging.info(
-            'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
+    if FLAGS.eval_training_data:
+      name = 'training_data'
+      input_fn = eval_on_train_input_fn
+    else:
+      name = 'validation_data'
+      input_fn = eval_input_fn
+    model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, eval_steps,
+                              train_steps, name)


 if __name__ == '__main__':
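
Because the loop now lives in model_lib.py, other binaries can reuse it directly. Below is a minimal sketch of such a call outside model_tpu_main.py; the toy model_fn, input_fn, step counts, and temp directory are placeholders for illustration only, not part of this commit (the real binaries build the estimator and inputs from the pipeline config):

import tempfile

import tensorflow as tf

from object_detection import model_lib


def toy_model_fn(features, labels, mode):
  # Constant loss; just enough for Estimator to train, checkpoint and evaluate.
  del features, labels  # unused in this toy example
  loss = tf.constant(0.0)
  global_step = tf.train.get_or_create_global_step()
  train_op = tf.assign_add(global_step, 1)
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)


def toy_input_fn():
  # A single dummy example, repeated indefinitely.
  return tf.data.Dataset.from_tensors(({'x': [[1.0]]}, [[1.0]])).repeat()


model_dir = tempfile.mkdtemp()
estimator = tf.estimator.Estimator(model_fn=toy_model_fn, model_dir=model_dir)

train_steps = 10
estimator.train(input_fn=toy_input_fn, max_steps=train_steps)

# Polls model_dir for checkpoints and stops once the final checkpoint
# (global step >= train_steps) has been evaluated.
model_lib.continuous_eval(estimator, model_dir, toy_input_fn, eval_steps=1,
                          train_steps=train_steps, name='validation_data')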
