@@ -864,12 +864,13 @@ def on_train_epoch_start(self, trainer, pl_module):
864
864
raise RuntimeError ("Trouble!" )
865
865
866
866
model = BoringModel ()
867
- epoch_length = 64
867
+ epoch_length = 2
868
868
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
869
869
trainer = Trainer (
870
870
default_root_dir = tmp_path ,
871
871
callbacks = [checkpoint_callback , TroublemakerOnTrainEpochStart ()],
872
872
max_epochs = 5 ,
873
+ limit_train_batches = epoch_length ,
873
874
logger = False ,
874
875
enable_progress_bar = False ,
875
876
)
@@ -887,12 +888,13 @@ def on_train_epoch_end(self, trainer, pl_module):
887
888
raise RuntimeError ("Trouble!" )
888
889
889
890
model = BoringModel ()
890
- epoch_length = 64
891
+ epoch_length = 2
891
892
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
892
893
trainer = Trainer (
893
894
default_root_dir = tmp_path ,
894
895
callbacks = [checkpoint_callback , TroublemakerOnTrainEpochEnd ()],
895
896
max_epochs = 5 ,
897
+ limit_train_batches = epoch_length ,
896
898
logger = False ,
897
899
enable_progress_bar = False ,
898
900
)
@@ -956,12 +958,13 @@ def on_validation_epoch_start(self, trainer, pl_module):
956
958
raise RuntimeError ("Trouble!" )
957
959
958
960
model = BoringModel ()
959
- epoch_length = 64
961
+ epoch_length = 2
960
962
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
961
963
trainer = Trainer (
962
964
default_root_dir = tmp_path ,
963
965
callbacks = [checkpoint_callback , TroublemakerOnValidationEpochStart ()],
964
966
max_epochs = 5 ,
967
+ limit_train_batches = epoch_length ,
965
968
logger = False ,
966
969
enable_progress_bar = False ,
967
970
)
@@ -979,12 +982,13 @@ def on_validation_epoch_end(self, trainer, pl_module):
979
982
raise RuntimeError ("Trouble!" )
980
983
981
984
model = BoringModel ()
982
- epoch_length = 64
985
+ epoch_length = 2
983
986
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
984
987
trainer = Trainer (
985
988
default_root_dir = tmp_path ,
986
989
callbacks = [checkpoint_callback , TroublemakerOnValidationEpochEnd ()],
987
990
max_epochs = 5 ,
991
+ limit_train_batches = epoch_length ,
988
992
logger = False ,
989
993
enable_progress_bar = False ,
990
994
)
@@ -1002,12 +1006,13 @@ def on_validation_start(self, trainer, pl_module):
1002
1006
raise RuntimeError ("Trouble!" )
1003
1007
1004
1008
model = BoringModel ()
1005
- epoch_length = 64
1009
+ epoch_length = 2
1006
1010
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1007
1011
trainer = Trainer (
1008
1012
default_root_dir = tmp_path ,
1009
1013
callbacks = [checkpoint_callback , TroublemakerOnValidationStart ()],
1010
1014
max_epochs = 5 ,
1015
+ limit_train_batches = epoch_length ,
1011
1016
logger = False ,
1012
1017
enable_progress_bar = False ,
1013
1018
)
@@ -1025,12 +1030,13 @@ def on_validation_end(self, trainer, pl_module):
1025
1030
raise RuntimeError ("Trouble!" )
1026
1031
1027
1032
model = BoringModel ()
1028
- epoch_length = 64
1033
+ epoch_length = 2
1029
1034
checkpoint_callback = ModelCheckpoint (dirpath = tmp_path , filename = "{step}" , save_on_exception = True , every_n_epochs = 4 )
1030
1035
trainer = Trainer (
1031
1036
default_root_dir = tmp_path ,
1032
1037
callbacks = [checkpoint_callback , TroublemakerOnValidationEnd ()],
1033
1038
max_epochs = 5 ,
1039
+ limit_train_batches = epoch_length ,
1034
1040
logger = False ,
1035
1041
enable_progress_bar = False ,
1036
1042
)
0 commit comments