
Commit e059e33

Support async checkpoint in Orbit trainer/controller.
This CL adds a field in the Orbit trainer/controller indicating whether async checkpointing is enabled for checkpoint saving. By default this value is set to False, which matches the existing behavior. In addition, a sync barrier is added at the end of training (in the controller) to make sure user code won't prematurely access the checkpoint files/state while async checkpoint saving is still in progress.

PiperOrigin-RevId: 529300903
1 parent f0ec955 commit e059e33
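
As a quick orientation for this change, here is a minimal usage sketch of the new flag on the Orbit Controller. The trainer, checkpoint manager, and step counts below are hypothetical placeholders; only `enable_async_checkpointing` is the parameter introduced by this commit.

import orbit
import tensorflow as tf

# Hypothetical objects: `my_trainer` is assumed to be an orbit trainer
# implementation and `ckpt_manager` a tf.train.CheckpointManager that both
# track the same `global_step` variable.
global_step = tf.Variable(0, dtype=tf.int64, trainable=False)

controller = orbit.Controller(
    trainer=my_trainer,
    global_step=global_step,
    steps_per_loop=100,
    checkpoint_manager=ckpt_manager,
    enable_async_checkpointing=True,  # new; defaults to False (sync saving)
)

# train() now ends with a sync barrier, so checkpoint files are safe to read
# as soon as this call returns, even with async saving enabled.
controller.train(steps=1000)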

2 files changed (+79, -12 lines)

orbit/controller.py (+25 lines)
@@ -96,6 +96,7 @@ def __init__(
       # Train related
       steps_per_loop: Optional[Union[int, Callable[[int], int]]] = None,
       checkpoint_manager: Optional[tf.train.CheckpointManager] = None,
+      enable_async_checkpointing: bool = False,
       # Summary related
       summary_interval: Optional[int] = None,
       summary_dir: Optional[str] = None,
@@ -141,6 +142,8 @@ def __init__(
         the model will be restored from the most recent checkpoint inside this
         `__init__` method. If not provided, the `Controller` will not
         automatically save to or restore from checkpoints.
+      enable_async_checkpointing: Optional bool indicating whether to enable
+        async checkpoint saving.
       summary_interval: Step interval for training summaries. Note that this
         argument only applies to `tf.summary` calls inside the `trainer.train`
         function. Summaries written by the `Controller` (specifically
@@ -204,6 +207,10 @@ def __init__(

     self.global_step = global_step
     self.checkpoint_manager = checkpoint_manager
+    self._enable_async_checkpoint_saving = enable_async_checkpointing
+    self._checkpoint_options = tf.train.CheckpointOptions(
+        enable_async=enable_async_checkpointing
+    )

     if self.trainer is not None:
       self.step_timer = None
@@ -244,6 +251,10 @@ def train(self, steps: int, checkpoint_at_completion: bool = True):
     `CheckpointManager` was passed to `Controller.__init__`) and summarize
     training output (if `summary_dir` is set).

+    When async checkpointing is enabled, a sync is triggered at the end of this
+    method to make sure any ongoing async checkpoint saving is finished before
+    returning.
+
     Args:
       steps: The global step count to train up to.
       checkpoint_at_completion: Whether to save a checkpoint when this method
@@ -264,6 +275,8 @@ def train(self, steps: int, checkpoint_at_completion: bool = True):
     if checkpoint_at_completion:
       self._maybe_save_checkpoint(check_interval=False)

+    self._sync_on_async_checkpointing()
+
   def evaluate(self, steps: int = -1) -> Optional[runner.Output]:
     """Runs evaluation for the given number of steps.
@@ -339,6 +352,10 @@ def train_and_evaluate(
     In addition, this method will run a final evaluation at the end of the
     training sequence.

+    When async checkpointing is enabled, a sync is triggered at the end of this
+    method to make sure any ongoing async checkpoint saving is finished before
+    returning.
+
     Args:
       train_steps: The global step count to train up to.
       eval_steps: The number of steps to run during an evaluation. If -1, this
@@ -365,6 +382,7 @@ def train_and_evaluate(
       output = self.evaluate(steps=eval_steps)
       current_step = self.global_step.numpy()
     self._maybe_save_checkpoint(check_interval=False)
+    self._sync_on_async_checkpointing()
     return output

   def evaluate_continuously(
@@ -539,6 +557,13 @@ def _require(self, attribute, for_method):
           f"`{attribute}` is not set. Pass `{attribute}` to "
           f"`Controller.__init__` before calling `{for_method}()`.")

+  def _sync_on_async_checkpointing(self):
+    """Force to wait for the async checkpoint saving (if any) to finish."""
+    # pylint: disable=protected-access
+    if self.checkpoint_manager:
+      logging.info("Sync on async checkpoint saving.")
+      self.checkpoint_manager.sync()
+

 class StepTimer:
   """Utility class for measuring steps/second."""

orbit/controller_test.py (+54, -12 lines)
@@ -294,7 +294,11 @@ def test_no_checkpoint_and_summaries(self):
         train_steps=10, eval_steps=2, eval_interval=6)
     self.assertEqual(test_runner.global_step, 10)

-  def test_has_checkpoint_no_summaries(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_has_checkpoint_no_summaries(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
     # Has checkpoint, but no summary directories.
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -308,6 +312,7 @@ def test_has_checkpoint_no_summaries(self):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         steps_per_loop=2)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)
@@ -317,7 +322,13 @@ def test_has_checkpoint_no_summaries(self):
     self.assertEmpty(tf.io.gfile.glob(
         os.path.join(checkpoint_manager.directory, "events.*")))

-  def test_has_checkpoint_eval_summary_only(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_has_checkpoint_eval_summary_only(
+      self, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner()
     # Has checkpoint, but no summary directories.
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
@@ -331,6 +342,7 @@ def test_has_checkpoint_eval_summary_only(self):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
         steps_per_loop=2)
     test_controller.train_and_evaluate(
@@ -344,7 +356,13 @@ def test_has_checkpoint_eval_summary_only(self):
     self.assertNotEmpty(tf.io.gfile.glob(
         os.path.join(self.model_dir, "summaries/eval/events.*")))

-  def test_restore_from_most_recent_checkpoint(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_restore_from_most_recent_checkpoint(
+      self, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner()
     checkpoint = tf.train.Checkpoint(model=test_runner.model)
     checkpoint_manager = tf.train.CheckpointManager(
@@ -357,16 +375,23 @@ def test_restore_from_most_recent_checkpoint(self):
         trainer=test_runner,
         global_step=test_runner.global_step,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
         steps_per_loop=5)
     test_controller.train(20)
     self.assertLen(checkpoint_manager.checkpoints, 4)
     restored_path = test_controller.restore_checkpoint()
     self.assertEqual(restored_path, checkpoint_manager.checkpoints[-1])

-  @parameterized.named_parameters(("return_numpy", True),
-                                  ("return_tensor", False))
-  def test_train_and_evaluate(self, return_numpy):
+  @parameterized.named_parameters(
+      ("return_numpy_sync_checkpoint_saving", True, False),
+      ("return_numpy_async_checkpoint_saving", True, True),
+      ("return_tensor_sync_checkpoint_saving", False, False),
+      ("return_tensor_async_checkpoint_saving", False, True),
+  )
+  def test_train_and_evaluate(
+      self, return_numpy, enable_async_checkpoint_saving
+  ):
     test_runner = TestRunner(return_numpy=return_numpy)

     checkpoint = tf.train.Checkpoint(
@@ -384,6 +409,7 @@ def test_train_and_evaluate(self, return_numpy):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)
@@ -403,7 +429,11 @@ def test_train_and_evaluate(self, return_numpy):
         summaries_with_matching_keyword(
             "eval_loss", os.path.join(self.model_dir, "summaries/eval")))

-  def test_train_only(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_train_only(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()

     checkpoint = tf.train.Checkpoint(
@@ -420,6 +450,7 @@ def test_train_only(self):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
     )
     test_controller.train(steps=10)
@@ -497,7 +528,11 @@ def test_no_eval_steps(self):
         checkpoint_manager=checkpoint_manager)
     test_controller.evaluate()

-  def test_already_trained_model(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_already_trained_model(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()
     test_runner.global_step.assign(10)

@@ -513,7 +548,8 @@ def test_already_trained_model(self):
         trainer=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=2,
-        checkpoint_manager=checkpoint_manager)
+        checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving)
     # `global_step` is already `train_steps`.
     test_controller.train(steps=10)

@@ -533,7 +569,7 @@ def test_summaries_inside_train_fn(self):
         steps_per_loop=2,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         summary_interval=2,
-        checkpoint_manager=checkpoint_manager,
+        checkpoint_manager=checkpoint_manager
     )
     test_controller.train(steps=10)

@@ -594,6 +630,7 @@ def train_and_evaluate(self,
       interval = min(train_steps - self.global_step.numpy(), eval_interval)
       num_steps = self.global_step.numpy() + interval
       self.train(steps=num_steps, checkpoint_at_completion=False)
+      self._sync_on_async_checkpointing()
       self.evaluate(steps=eval_steps)
       # Early stop condition.
       if test_runner.eval_loss.result() < 0.1:
@@ -672,7 +709,11 @@ def test_train_and_evaluate_reset_datasets(self):
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=6)

-  def test_eval_and_checkpoint_interval(self):
+  @parameterized.named_parameters(
+      ("_sync_checkpoint_saving", False),
+      ("_async_checkpoint_saving", True)
+  )
+  def test_eval_and_checkpoint_interval(self, enable_async_checkpoint_saving):
     test_runner = TestRunner()

     checkpoint = tf.train.Checkpoint(
@@ -689,6 +730,7 @@ def test_eval_and_checkpoint_interval(self):
         global_step=test_runner.global_step,
         steps_per_loop=10,
         checkpoint_manager=checkpoint_manager,
+        enable_async_checkpointing=enable_async_checkpoint_saving,
         summary_dir=self.model_dir)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=5)
@@ -803,7 +845,7 @@ def steps_per_loop_fn(global_step):
         trainer=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=steps_per_loop_fn,
-        checkpoint_manager=checkpoint_manager,
+        checkpoint_manager=checkpoint_manager
     )
     test_controller.train(steps=10)
     self.assertEqual(test_runner.global_step, 10)
