@@ -724,7 +724,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
724
724
725
725
@staticmethod
726
726
def _is_done (state : State ) -> bool :
727
- return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
727
+ is_done_count = (
728
+ state .epoch_length is not None
729
+ and state .max_epochs is not None
730
+ and state .iteration >= state .epoch_length * state .max_epochs
731
+ )
732
+ is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
733
+ return is_done_count or is_done_epochs
728
734
729
735
def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
730
736
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -956,7 +962,6 @@ def _internal_run_as_gen(self) -> Generator:
956
962
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
957
963
958
964
handlers_start_time = time .time ()
959
-
960
965
self ._fire_event (Events .EPOCH_COMPLETED )
961
966
epoch_time_taken += time .time () - handlers_start_time
962
967
# update time wrt handlers
@@ -1039,13 +1044,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1039
1044
1040
1045
# Should exit while loop if we can not iterate
1041
1046
if should_exit :
1042
- if not self ._is_done (self .state ):
1043
- total_iters = (
1044
- self .state .epoch_length * self .state .max_epochs
1045
- if self .state .max_epochs is not None
1046
- else self .state .max_iters
1047
- )
1048
-
1047
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1048
+ total_iters = self .state .epoch_length * self .state .max_epochs
1049
1049
warnings .warn (
1050
1050
"Data iterator can not provide data anymore but required total number of "
1051
1051
"iterations to run is not reached. "
@@ -1072,10 +1072,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1072
1072
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1073
1073
break
1074
1074
1075
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1076
- self .should_terminate = True
1077
- raise _EngineTerminateException ()
1078
-
1079
1075
except _EngineTerminateSingleEpochException :
1080
1076
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1081
1077
self .should_terminate_single_epoch = False
@@ -1191,19 +1187,12 @@ def _run_once_on_dataset_legacy(self) -> float:
1191
1187
if self .state .epoch_length is None :
1192
1188
# Define epoch length and stop the epoch
1193
1189
self .state .epoch_length = iter_counter
1194
- if self .state .max_iters is not None :
1195
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1196
1190
break
1197
1191
1198
1192
# Should exit while loop if we can not iterate
1199
1193
if should_exit :
1200
- if not self ._is_done (self .state ):
1201
- total_iters = (
1202
- self .state .epoch_length * self .state .max_epochs
1203
- if self .state .max_epochs is not None
1204
- else self .state .max_iters
1205
- )
1206
-
1194
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1195
+ total_iters = self .state .epoch_length * self .state .max_epochs
1207
1196
warnings .warn (
1208
1197
"Data iterator can not provide data anymore but required total number of "
1209
1198
"iterations to run is not reached. "
@@ -1230,10 +1219,6 @@ def _run_once_on_dataset_legacy(self) -> float:
1230
1219
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1231
1220
break
1232
1221
1233
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1234
- self .should_terminate = True
1235
- raise _EngineTerminateException ()
1236
-
1237
1222
except _EngineTerminateSingleEpochException :
1238
1223
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1239
1224
self .should_terminate_single_epoch = False
0 commit comments