@@ -294,7 +294,11 @@ def test_no_checkpoint_and_summaries(self):
294
294
train_steps = 10 , eval_steps = 2 , eval_interval = 6 )
295
295
self .assertEqual (test_runner .global_step , 10 )
296
296
297
- def test_has_checkpoint_no_summaries (self ):
297
+ @parameterized .named_parameters (
298
+ ("_sync_checkpoint_saving" , False ),
299
+ ("_async_checkpoint_saving" , True )
300
+ )
301
+ def test_has_checkpoint_no_summaries (self , enable_async_checkpoint_saving ):
298
302
test_runner = TestRunner ()
299
303
# Has checkpoint, but no summary directories.
300
304
checkpoint = tf .train .Checkpoint (model = test_runner .model )
@@ -308,6 +312,7 @@ def test_has_checkpoint_no_summaries(self):
308
312
evaluator = test_runner ,
309
313
global_step = test_runner .global_step ,
310
314
checkpoint_manager = checkpoint_manager ,
315
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
311
316
steps_per_loop = 2 )
312
317
test_controller .train_and_evaluate (
313
318
train_steps = 10 , eval_steps = 2 , eval_interval = 6 )
@@ -317,7 +322,13 @@ def test_has_checkpoint_no_summaries(self):
317
322
self .assertEmpty (tf .io .gfile .glob (
318
323
os .path .join (checkpoint_manager .directory , "events.*" )))
319
324
320
- def test_has_checkpoint_eval_summary_only (self ):
325
+ @parameterized .named_parameters (
326
+ ("_sync_checkpoint_saving" , False ),
327
+ ("_async_checkpoint_saving" , True )
328
+ )
329
+ def test_has_checkpoint_eval_summary_only (
330
+ self , enable_async_checkpoint_saving
331
+ ):
321
332
test_runner = TestRunner ()
322
333
# Has checkpoint and an eval summary directory only.
323
334
checkpoint = tf .train .Checkpoint (model = test_runner .model )
@@ -331,6 +342,7 @@ def test_has_checkpoint_eval_summary_only(self):
331
342
evaluator = test_runner ,
332
343
global_step = test_runner .global_step ,
333
344
checkpoint_manager = checkpoint_manager ,
345
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
334
346
eval_summary_dir = os .path .join (self .model_dir , "summaries/eval" ),
335
347
steps_per_loop = 2 )
336
348
test_controller .train_and_evaluate (
@@ -344,7 +356,13 @@ def test_has_checkpoint_eval_summary_only(self):
344
356
self .assertNotEmpty (tf .io .gfile .glob (
345
357
os .path .join (self .model_dir , "summaries/eval/events.*" )))
346
358
347
- def test_restore_from_most_recent_checkpoint (self ):
359
+ @parameterized .named_parameters (
360
+ ("_sync_checkpoint_saving" , False ),
361
+ ("_async_checkpoint_saving" , True )
362
+ )
363
+ def test_restore_from_most_recent_checkpoint (
364
+ self , enable_async_checkpoint_saving
365
+ ):
348
366
test_runner = TestRunner ()
349
367
checkpoint = tf .train .Checkpoint (model = test_runner .model )
350
368
checkpoint_manager = tf .train .CheckpointManager (
@@ -357,16 +375,23 @@ def test_restore_from_most_recent_checkpoint(self):
357
375
trainer = test_runner ,
358
376
global_step = test_runner .global_step ,
359
377
checkpoint_manager = checkpoint_manager ,
378
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
360
379
eval_summary_dir = os .path .join (self .model_dir , "summaries/eval" ),
361
380
steps_per_loop = 5 )
362
381
test_controller .train (20 )
363
382
self .assertLen (checkpoint_manager .checkpoints , 4 )
364
383
restored_path = test_controller .restore_checkpoint ()
365
384
self .assertEqual (restored_path , checkpoint_manager .checkpoints [- 1 ])
366
385
367
- @parameterized .named_parameters (("return_numpy" , True ),
368
- ("return_tensor" , False ))
369
- def test_train_and_evaluate (self , return_numpy ):
386
+ @parameterized .named_parameters (
387
+ ("return_numpy_sync_checkpoint_saving" , True , False ),
388
+ ("return_numpy_async_checkpoint_saving" , True , True ),
389
+ ("return_tensor_sync_checkpoint_saving" , False , False ),
390
+ ("return_tensor_async_checkpoint_saving" , False , True ),
391
+ )
392
+ def test_train_and_evaluate (
393
+ self , return_numpy , enable_async_checkpoint_saving
394
+ ):
370
395
test_runner = TestRunner (return_numpy = return_numpy )
371
396
372
397
checkpoint = tf .train .Checkpoint (
@@ -384,6 +409,7 @@ def test_train_and_evaluate(self, return_numpy):
384
409
steps_per_loop = 2 ,
385
410
summary_dir = os .path .join (self .model_dir , "summaries/train" ),
386
411
checkpoint_manager = checkpoint_manager ,
412
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
387
413
eval_summary_dir = os .path .join (self .model_dir , "summaries/eval" ))
388
414
test_controller .train_and_evaluate (
389
415
train_steps = 10 , eval_steps = 2 , eval_interval = 6 )
@@ -403,7 +429,11 @@ def test_train_and_evaluate(self, return_numpy):
403
429
summaries_with_matching_keyword (
404
430
"eval_loss" , os .path .join (self .model_dir , "summaries/eval" )))
405
431
406
- def test_train_only (self ):
432
+ @parameterized .named_parameters (
433
+ ("_sync_checkpoint_saving" , False ),
434
+ ("_async_checkpoint_saving" , True )
435
+ )
436
+ def test_train_only (self , enable_async_checkpoint_saving ):
407
437
test_runner = TestRunner ()
408
438
409
439
checkpoint = tf .train .Checkpoint (
@@ -420,6 +450,7 @@ def test_train_only(self):
420
450
steps_per_loop = 2 ,
421
451
summary_dir = os .path .join (self .model_dir , "summaries/train" ),
422
452
checkpoint_manager = checkpoint_manager ,
453
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
423
454
eval_summary_dir = os .path .join (self .model_dir , "summaries/eval" ),
424
455
)
425
456
test_controller .train (steps = 10 )
@@ -497,7 +528,11 @@ def test_no_eval_steps(self):
497
528
checkpoint_manager = checkpoint_manager )
498
529
test_controller .evaluate ()
499
530
500
- def test_already_trained_model (self ):
531
+ @parameterized .named_parameters (
532
+ ("_sync_checkpoint_saving" , False ),
533
+ ("_async_checkpoint_saving" , True )
534
+ )
535
+ def test_already_trained_model (self , enable_async_checkpoint_saving ):
501
536
test_runner = TestRunner ()
502
537
test_runner .global_step .assign (10 )
503
538
@@ -513,7 +548,8 @@ def test_already_trained_model(self):
513
548
trainer = test_runner ,
514
549
global_step = test_runner .global_step ,
515
550
steps_per_loop = 2 ,
516
- checkpoint_manager = checkpoint_manager )
551
+ checkpoint_manager = checkpoint_manager ,
552
+ enable_async_checkpointing = enable_async_checkpoint_saving )
517
553
# `global_step` is already `train_steps`.
518
554
test_controller .train (steps = 10 )
519
555
@@ -533,7 +569,7 @@ def test_summaries_inside_train_fn(self):
533
569
steps_per_loop = 2 ,
534
570
summary_dir = os .path .join (self .model_dir , "summaries/train" ),
535
571
summary_interval = 2 ,
536
- checkpoint_manager = checkpoint_manager ,
572
+ checkpoint_manager = checkpoint_manager
537
573
)
538
574
test_controller .train (steps = 10 )
539
575
@@ -594,6 +630,7 @@ def train_and_evaluate(self,
594
630
interval = min (train_steps - self .global_step .numpy (), eval_interval )
595
631
num_steps = self .global_step .numpy () + interval
596
632
self .train (steps = num_steps , checkpoint_at_completion = False )
633
+ self ._sync_on_async_checkpointing ()
597
634
self .evaluate (steps = eval_steps )
598
635
# Early stop condition.
599
636
if test_runner .eval_loss .result () < 0.1 :
@@ -672,7 +709,11 @@ def test_train_and_evaluate_reset_datasets(self):
672
709
test_controller .train_and_evaluate (
673
710
train_steps = 10 , eval_steps = 2 , eval_interval = 6 )
674
711
675
- def test_eval_and_checkpoint_interval (self ):
712
+ @parameterized .named_parameters (
713
+ ("_sync_checkpoint_saving" , False ),
714
+ ("_async_checkpoint_saving" , True )
715
+ )
716
+ def test_eval_and_checkpoint_interval (self , enable_async_checkpoint_saving ):
676
717
test_runner = TestRunner ()
677
718
678
719
checkpoint = tf .train .Checkpoint (
@@ -689,6 +730,7 @@ def test_eval_and_checkpoint_interval(self):
689
730
global_step = test_runner .global_step ,
690
731
steps_per_loop = 10 ,
691
732
checkpoint_manager = checkpoint_manager ,
733
+ enable_async_checkpointing = enable_async_checkpoint_saving ,
692
734
summary_dir = self .model_dir )
693
735
test_controller .train_and_evaluate (
694
736
train_steps = 10 , eval_steps = 2 , eval_interval = 5 )
@@ -803,7 +845,7 @@ def steps_per_loop_fn(global_step):
803
845
trainer = test_runner ,
804
846
global_step = test_runner .global_step ,
805
847
steps_per_loop = steps_per_loop_fn ,
806
- checkpoint_manager = checkpoint_manager ,
848
+ checkpoint_manager = checkpoint_manager
807
849
)
808
850
test_controller .train (steps = 10 )
809
851
self .assertEqual (test_runner .global_step , 10 )
0 commit comments