@@ -773,27 +773,25 @@ def serialized_fn(mtf_features):
       gin_config_saver_hook = gin.tf.GinConfigSaverHook(
           model_dir, summarize_config=True, include_step_in_filename=False)

+      training_hooks = [
+          restore_hook,
+          saver_hook,
+          gin_config_saver_hook,
+      ]
+
       if use_tpu:
         return tpu_estimator.TPUEstimatorSpec(
             mode=tf.estimator.ModeKeys.TRAIN,
             loss=tf_loss,
             train_op=train_op,
             host_call=host_call,
-            training_hooks=[
-                restore_hook,
-                saver_hook,
-                gin_config_saver_hook,
-            ])
+            training_hooks=training_hooks)
       else:
         return tf.estimator.EstimatorSpec(
             tf.estimator.ModeKeys.TRAIN,
             loss=tf_loss,
             train_op=train_op,
-            training_chief_hooks=[
-                restore_hook,
-                saver_hook,
-                gin_config_saver_hook,
-            ])
+            training_chief_hooks=training_hooks)
     elif mode == tf.estimator.ModeKeys.EVAL:
       # perplexity eval
       logits, loss = logits_and_loss(mtf_features)
@@ -1698,9 +1696,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
       model_dir=model_dir,
       tpu_config=my_tpu_config,
       session_config=session_config,
-      # We use a saver hook, so disable checkpoints here to prevent double
-      # saving.
-      save_checkpoints_steps=None,
+      save_checkpoints_steps=save_checkpoints_steps,
       save_checkpoints_secs=None)

   transformer_model = build_model(
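Note (not part of the diff): the removed comment explained that checkpointing was previously disabled because a separate saver hook handled saving; with this change the caller's `save_checkpoints_steps` is honored again. A minimal illustrative sketch of such a config, assuming TF 1.x's `tf.estimator.tpu.RunConfig` (the concrete values and path are hypothetical, not from this PR):

import tensorflow.compat.v1 as tf

run_config = tf.estimator.tpu.RunConfig(
    cluster=None,
    model_dir="/tmp/model",  # hypothetical model directory
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100),
    session_config=tf.ConfigProto(allow_soft_placement=True),
    save_checkpoints_steps=1000,  # passed through instead of forced to None
    save_checkpoints_secs=None)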
@@ -1748,7 +1744,7 @@ def get_estimator(model_type, vocabulary, mesh_shape,
 def train_model(estimator, vocabulary, sequence_length, batch_size,
                 train_dataset_fn, train_steps, ensemble_inputs,
                 dataset_split="train", skip_seen_data=False,
-                seen_data_init_step=0):
+                seen_data_init_step=0, checkpoint_input_pipeline=False):
   """Train a Mesh-TF model.

   Args:
@@ -1773,11 +1769,20 @@ def train_model(estimator, vocabulary, sequence_length, batch_size,
     skip_seen_data: a boolean, is `False` by default. Used when a training run
       restarts to skip already seen data. This flag is only consistent when
       every setting (such as batch size and random seed) on the model is the
-      same between the original run and the new run.
+      same between the original run and the new run. May require a significant
+      amount of time to skip a large number of steps.
     seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
       steps from this starting point. Useful when finetuning.
+    checkpoint_input_pipeline: a boolean, whether to checkpoint the input
+      pipeline in order to restart from the previous run. May require a large
+      amount of disk space for complicated input pipelines.
   """

+  if skip_seen_data and checkpoint_input_pipeline:
+    raise ValueError(
+        "At most one of `skip_seen_data` and `checkpoint_input_pipeline` may "
+        "be set.")
+
   def input_fn(params):
     del params

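The new guard makes the two restart strategies mutually exclusive. A toy sketch (illustrative TF 1.x `tf.data` code, not from this PR) of why the `skip_seen_data` path can be slow: `skip` has no random access, so every skipped element is still produced by the upstream pipeline.

import tensorflow.compat.v1 as tf

def expensive_parse(x):
  # Stand-in for heavy per-example parsing or augmentation.
  return x + 1

dataset = tf.data.Dataset.range(10**6).map(expensive_parse).batch(8)

# Resuming via skip: the map function still runs for all skipped examples,
# so skipping a large number of already-seen batches takes real time.
steps_to_skip = 5000  # hypothetical count of already-seen batches
resumed = dataset.skip(steps_to_skip)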
@@ -1799,7 +1804,12 @@ def input_fn(params):
       dataset = dataset.skip(steps_to_skip)
     return dataset

-  estimator.train(input_fn=input_fn, max_steps=train_steps)
+  hooks = []
+  if checkpoint_input_pipeline:
+    hooks.append(
+        tf.data.experimental.CheckpointInputPipelineHook(estimator))
+
+  estimator.train(input_fn=input_fn, max_steps=train_steps, hooks=hooks)


 @gin.configurable
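For readers unfamiliar with the hook used above, here is a self-contained toy sketch of `tf.data.experimental.CheckpointInputPipelineHook` with a plain estimator (the model_fn, values, and path are stand-ins, not Mesh-TF code). The hook saves the input iterator's state alongside the model checkpoint and restores it on restart, which is the behavior `checkpoint_input_pipeline=True` enables here.

import tensorflow.compat.v1 as tf

def input_fn():
  # Toy infinite pipeline standing in for the real training data.
  return tf.data.Dataset.range(1000).repeat().batch(8)

def model_fn(features, labels, mode):
  del labels
  w = tf.get_variable("w", [], dtype=tf.float32)
  loss = tf.reduce_mean(w * tf.cast(features, tf.float32))
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/demo")
hooks = [tf.data.experimental.CheckpointInputPipelineHook(estimator)]
estimator.train(input_fn=input_fn, max_steps=100, hooks=hooks)
# A later call (e.g. after a restart) resumes mid-dataset rather than
# replaying the pipeline from the first element.
estimator.train(input_fn=input_fn, max_steps=200, hooks=hooks)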
@@ -2399,7 +2409,8 @@ def run(tpu_job_name,
         train_model_fn=train_model,
         skip_seen_data=False,
         seen_data_init_step=0,
-        output_eval_examples=True):
+        output_eval_examples=True,
+        checkpoint_input_pipeline=False):
   """Run training, eval, or inference depending on `mode`.

   Args:
@@ -2465,12 +2476,16 @@ def run(tpu_job_name,
     skip_seen_data: a boolean, is `False` by default. Used when a training run
       restarts to skip already seen data. This flag is only consistent when
       every setting (such as batch size and random seed) on the model is the
-      same between the original run and the new run.
+      same between the original run and the new run. May require a significant
+      amount of time to skip a large number of steps.
     seen_data_init_step: an integer, when `skip_seen_data` is True, skip seen
       steps from this starting point. Useful when finetuning.
     output_eval_examples: a boolean, is `True` by default. Used to decide
       whether to dump inputs, targets, and predictions of the eval
       examples in plaintext to eval_summary_dir.
+    checkpoint_input_pipeline: a boolean, whether to checkpoint the input
+      pipeline in order to restart from the previous run. May require a large
+      amount of disk space for complicated input pipelines.
   """
   if isinstance(sequence_length, int):
     sequence_length = {"inputs": sequence_length,
@@ -2560,7 +2575,8 @@ def run(tpu_job_name,
     train_model_fn(estimator, vocabulary, sequence_length, batch_size,
                    train_dataset_fn, train_steps, ensemble_inputs,
                    skip_seen_data=skip_seen_data,
-                   seen_data_init_step=seen_data_init_step)
+                   seen_data_init_step=seen_data_init_step,
+                   checkpoint_input_pipeline=checkpoint_input_pipeline)

   elif mode == "perplexity_eval":
     if eval_dataset_fn is None: