Skip to content

Commit 6249794

Browse files
committed
Update save checkpoint on exception tests to use a shorter more precisly defined epoch lenght
1 parent 985c1e1 commit 6249794

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -864,12 +864,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864864
raise RuntimeError("Trouble!")
865865

866866
model = BoringModel()
867-
epoch_length = 64
867+
epoch_length = 2
868868
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
869869
trainer = Trainer(
870870
default_root_dir=tmp_path,
871871
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochStart()],
872872
max_epochs=5,
873+
limit_train_batches=epoch_length,
873874
logger=False,
874875
enable_progress_bar=False,
875876
)
@@ -887,12 +888,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887888
raise RuntimeError("Trouble!")
888889

889890
model = BoringModel()
890-
epoch_length = 64
891+
epoch_length = 2
891892
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
892893
trainer = Trainer(
893894
default_root_dir=tmp_path,
894895
callbacks=[checkpoint_callback, TroublemakerOnTrainEpochEnd()],
895896
max_epochs=5,
897+
limit_train_batches=epoch_length,
896898
logger=False,
897899
enable_progress_bar=False,
898900
)
@@ -956,12 +958,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956958
raise RuntimeError("Trouble!")
957959

958960
model = BoringModel()
959-
epoch_length = 64
961+
epoch_length = 2
960962
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
961963
trainer = Trainer(
962964
default_root_dir=tmp_path,
963965
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochStart()],
964966
max_epochs=5,
967+
limit_train_batches=epoch_length,
965968
logger=False,
966969
enable_progress_bar=False,
967970
)
@@ -979,12 +982,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979982
raise RuntimeError("Trouble!")
980983

981984
model = BoringModel()
982-
epoch_length = 64
985+
epoch_length = 2
983986
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
984987
trainer = Trainer(
985988
default_root_dir=tmp_path,
986989
callbacks=[checkpoint_callback, TroublemakerOnValidationEpochEnd()],
987990
max_epochs=5,
991+
limit_train_batches=epoch_length,
988992
logger=False,
989993
enable_progress_bar=False,
990994
)
@@ -1002,12 +1006,13 @@ def on_validation_start(self, trainer, pl_module):
10021006
raise RuntimeError("Trouble!")
10031007

10041008
model = BoringModel()
1005-
epoch_length = 64
1009+
epoch_length = 2
10061010
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
10071011
trainer = Trainer(
10081012
default_root_dir=tmp_path,
10091013
callbacks=[checkpoint_callback, TroublemakerOnValidationStart()],
10101014
max_epochs=5,
1015+
limit_train_batches=epoch_length,
10111016
logger=False,
10121017
enable_progress_bar=False,
10131018
)
@@ -1025,12 +1030,13 @@ def on_validation_end(self, trainer, pl_module):
10251030
raise RuntimeError("Trouble!")
10261031

10271032
model = BoringModel()
1028-
epoch_length = 64
1033+
epoch_length = 2
10291034
checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="{step}", save_on_exception=True, every_n_epochs=4)
10301035
trainer = Trainer(
10311036
default_root_dir=tmp_path,
10321037
callbacks=[checkpoint_callback, TroublemakerOnValidationEnd()],
10331038
max_epochs=5,
1039+
limit_train_batches=epoch_length,
10341040
logger=False,
10351041
enable_progress_bar=False,
10361042
)

0 commit comments

Comments
 (0)