
Commit 27613a9

adarob authored and Mesh TensorFlow Team committed
Add optional checkpointing of input pipeline during training.
PiperOrigin-RevId: 359098125
1 parent 9625f34 commit 27613a9


mesh_tensorflow/transformer/utils.py

Lines changed: 35 additions & 19 deletions
@@ -773,27 +773,25 @@ def serialized_fn(mtf_features):
     gin_config_saver_hook = gin.tf.GinConfigSaverHook(
         model_dir, summarize_config=True, include_step_in_filename=False)

+    training_hooks = [
+        restore_hook,
+        saver_hook,
+        gin_config_saver_hook,
+    ]
+
     if use_tpu:
       return tpu_estimator.TPUEstimatorSpec(
           mode=tf.estimator.ModeKeys.TRAIN,
           loss=tf_loss,
           train_op=train_op,
           host_call=host_call,
-          training_hooks=[
-              restore_hook,
-              saver_hook,
-              gin_config_saver_hook,
-          ])
+          training_hooks=training_hooks)
     else:
       return tf.estimator.EstimatorSpec(
           tf.estimator.ModeKeys.TRAIN,
           loss=tf_loss,
           train_op=train_op,
-          training_chief_hooks=[
-              restore_hook,
-              saver_hook,
-              gin_config_saver_hook,
-          ])
+          training_chief_hooks=training_hooks)
   elif mode == tf.estimator.ModeKeys.EVAL:
     # perplexity eval
     logits, loss = logits_and_loss(mtf_features)
@@ -1698,9 +1696,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       model_dir=model_dir,
       tpu_config=my_tpu_config,
       session_config=session_config,
-      # We use a saver hook, so disable checkpoints here to prevent double
-      # saving.
-      save_checkpoints_steps=None,
+      save_checkpoints_steps=save_checkpoints_steps,
       save_checkpoints_secs=None)

   transformer_model = build_model(
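The RunConfig hunk above stops forcing save_checkpoints_steps to None (the old comment relied on a saver hook alone) and instead forwards the save_checkpoints_steps value already in scope in get_estimator, so the Estimator itself schedules checkpoints again. A minimal sketch of the API being configured here, with hypothetical values and using the public tf.estimator.tpu names rather than whatever module path utils.py actually imports:

# Minimal sketch, not from this commit: forwarding a checkpoint cadence
# into a TPU RunConfig via the public TF 1.x API. The model_dir and step
# counts below are hypothetical.
import tensorflow.compat.v1 as tf

my_tpu_config = tf.estimator.tpu.TPUConfig(
    iterations_per_loop=100)          # hypothetical host-loop size

run_config = tf.estimator.tpu.RunConfig(
    model_dir="/tmp/model",           # hypothetical model directory
    tpu_config=my_tpu_config,
    save_checkpoints_steps=1000,      # checkpoint every 1000 steps
    save_checkpoints_secs=None)       # no time-based checkpointing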
@@ -1748,7 +1744,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
 def train_model(estimator, vocabulary, sequence_length, batch_size,
                 train_dataset_fn, train_steps, ensemble_inputs,
                 dataset_split="train", skip_seen_data=False,
-                seen_data_init_step=0):
+                seen_data_init_step=0, checkpoint_input_pipeline=False):
   """Train a Mesh-TF model.

   Args:
@@ -1773,11 +1769,20 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
     skip_seen_data: a boolean, is `False` by default. Used when a training run
       restarts to skip already seen data. This flag is only consistent when
       every setting (such as batch size and random seed) on the model is the
-      same between the original run and the new run.
+      same between the original run and the new run. May require a significant
+      amount of time to skip a large number of steps.
     seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
       steps from this starting point. Useful when finetuning.
+    checkpoint_input_pipeline: a boolean, whether to checkpoint the input
+      pipeline in order to restart from the previous run. May require a large
+      amount of disk space for complicated input pipelines.
   """

+  if skip_seen_data and checkpoint_input_pipeline:
+    raise ValueError(
+        "At most one of `skip_seen_data` and `checkpoint_input_pipeline` may "
+        "be set.")
+
   def input_fn(params):
     del params

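Taken together, the docstring and the new ValueError make skip_seen_data and checkpoint_input_pipeline two mutually exclusive ways to resume an interrupted run: replay the data stream and skip what was already seen, or save the iterator state itself. A hedged illustration of the two call patterns; every argument below is a hypothetical placeholder for an object normally built elsewhere (for example by get_estimator() and a dataset function):

# Hedged illustration, not from this commit: contrasting the two resume
# strategies exposed by train_model().
from mesh_tensorflow.transformer.utils import train_model


def resume_examples(estimator, vocabulary, train_dataset_fn):
  """Hypothetical helper contrasting the two resume modes."""
  common = dict(
      estimator=estimator,
      vocabulary=vocabulary,
      sequence_length={"inputs": 512, "targets": 512},
      batch_size=128,
      train_dataset_fn=train_dataset_fn,
      train_steps=10000,
      ensemble_inputs=None)

  # Resume by replaying the input stream and skipping already-seen data:
  # no extra disk space, but skipping many steps can take a long time.
  train_model(skip_seen_data=True, seen_data_init_step=5000, **common)

  # Resume from a checkpointed tf.data iterator: restarts quickly, but
  # may need a large amount of disk space for complicated pipelines.
  train_model(checkpoint_input_pipeline=True, **common)

  # Setting both flags at once raises the ValueError added in this commit.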
@@ -1799,7 +1804,12 @@ def input_fn(params):
       dataset = dataset.skip(steps_to_skip)
     return dataset

-  estimator.train(input_fn=input_fn, max_steps=train_steps)
+  hooks = []
+  if checkpoint_input_pipeline:
+    hooks.append(
+        tf.data.experimental.CheckpointInputPipelineHook(estimator))
+
+  estimator.train(input_fn=input_fn, max_steps=train_steps, hooks=hooks)


 @gin.configurable
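The hook wired in above is TensorFlow's tf.data.experimental.CheckpointInputPipelineHook, which saves the tf.data iterator state into the estimator's model_dir alongside the model checkpoints, so a restarted job resumes reading where it stopped instead of from the start of the dataset. A self-contained sketch of the same wiring outside Mesh TensorFlow, in TF 1.x style with a hypothetical toy model and dataset:

# Self-contained sketch, not from this commit: a toy Estimator whose
# input-pipeline position is checkpointed by CheckpointInputPipelineHook.
import tensorflow.compat.v1 as tf


def toy_input_fn():
  # The iterator position over this dataset is what the hook checkpoints.
  return tf.data.Dataset.range(10000).repeat().batch(8)


def toy_model_fn(features, labels, mode):
  del labels  # The toy dataset has no labels.
  w = tf.get_variable("w", [], initializer=tf.zeros_initializer())
  loss = tf.reduce_mean((tf.cast(features, tf.float32) - w) ** 2)
  train_op = tf.train.GradientDescentOptimizer(1e-4).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


estimator = tf.estimator.Estimator(model_fn=toy_model_fn,
                                   model_dir="/tmp/toy_model")
hooks = [tf.data.experimental.CheckpointInputPipelineHook(estimator)]
# Rerunning this call after an interruption restores both the model and
# the input-pipeline position from /tmp/toy_model.
estimator.train(input_fn=toy_input_fn, max_steps=100, hooks=hooks)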
@@ -2399,7 +2409,8 @@ def run(tpu_job_name,
         train_model_fn=train_model,
         skip_seen_data=False,
         seen_data_init_step=0,
-        output_eval_examples=True):
+        output_eval_examples=True,
+        checkpoint_input_pipeline=False):
   """Run training, eval, or inference depending on `mode`.

   Args:
@@ -2465,12 +2476,16 @@ def run(tpu_job_name,
     skip_seen_data: a boolean, is `False` by default. Used when a training run
       restarts to skip already seen data. This flag is only consistent when
       every setting (such as batch size and random seed) on the model is the
-      same between the original run and the new run.
+      same between the original run and the new run. May require a significant
+      amount of time to skip a large number of steps.
     seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
       steps from this starting point. Useful when finetuning.
     output_eval_examples: a boolean, is `True` by default. Used to decide
       whether to output whether to dump inputs, targets, and predictions of the
       eval examples in plaintext to eval_summary_dir.
+    checkpoint_input_pipeline: a boolean, whether to checkpoint the input
+      pipeline in order to restart from the previous run. May require a large
+      amount of disk space for complicated input pipelines.
   """
   if isinstance(sequence_length, int):
     sequence_length = {"inputs": sequence_length,
@@ -2560,7 +2575,8 @@ def run(tpu_job_name,
     train_model_fn(estimator, vocabulary, sequence_length, batch_size,
                    train_dataset_fn, train_steps, ensemble_inputs,
                    skip_seen_data=skip_seen_data,
-                   seen_data_init_step=seen_data_init_step)
+                   seen_data_init_step=seen_data_init_step,
+                   checkpoint_input_pipeline=checkpoint_input_pipeline)

   elif mode == "perplexity_eval":
     if eval_dataset_fn is None:
