Skip to content

Commit f264179

Browse files
committed
Fixed other things due to reverted commits
1 parent c7348f7 commit f264179

File tree

8 files changed

+17
-72
lines changed

8 files changed

+17
-72
lines changed

ignite/contrib/handlers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ignite.contrib.handlers.clearml_logger import ClearMLLogger
22
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
3-
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
43
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
54
from ignite.contrib.handlers.neptune_logger import NeptuneLogger
65
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger

ignite/engine/engine.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
724724

725725
@staticmethod
726726
def _is_done(state: State) -> bool:
727-
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]
727+
is_done_count = (
728+
state.epoch_length is not None
729+
and state.max_epochs is not None
730+
and state.iteration >= state.epoch_length * state.max_epochs
731+
)
732+
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
733+
return is_done_count or is_done_epochs
728734

729735
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
730736
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -956,7 +962,6 @@ def _internal_run_as_gen(self) -> Generator:
956962
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
957963

958964
handlers_start_time = time.time()
959-
960965
self._fire_event(Events.EPOCH_COMPLETED)
961966
epoch_time_taken += time.time() - handlers_start_time
962967
# update time wrt handlers
@@ -1039,13 +1044,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10391044

10401045
# Should exit while loop if we can not iterate
10411046
if should_exit:
1042-
if not self._is_done(self.state):
1043-
total_iters = (
1044-
self.state.epoch_length * self.state.max_epochs
1045-
if self.state.max_epochs is not None
1046-
else self.state.max_iters
1047-
)
1048-
1047+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1048+
total_iters = self.state.epoch_length * self.state.max_epochs
10491049
warnings.warn(
10501050
"Data iterator can not provide data anymore but required total number of "
10511051
"iterations to run is not reached. "
@@ -1072,10 +1072,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10721072
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
10731073
break
10741074

1075-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1076-
self.should_terminate = True
1077-
raise _EngineTerminateException()
1078-
10791075
except _EngineTerminateSingleEpochException:
10801076
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
10811077
self.should_terminate_single_epoch = False
@@ -1191,19 +1187,12 @@ def _run_once_on_dataset_legacy(self) -> float:
11911187
if self.state.epoch_length is None:
11921188
# Define epoch length and stop the epoch
11931189
self.state.epoch_length = iter_counter
1194-
if self.state.max_iters is not None:
1195-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
11961190
break
11971191

11981192
# Should exit while loop if we can not iterate
11991193
if should_exit:
1200-
if not self._is_done(self.state):
1201-
total_iters = (
1202-
self.state.epoch_length * self.state.max_epochs
1203-
if self.state.max_epochs is not None
1204-
else self.state.max_iters
1205-
)
1206-
1194+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1195+
total_iters = self.state.epoch_length * self.state.max_epochs
12071196
warnings.warn(
12081197
"Data iterator can not provide data anymore but required total number of "
12091198
"iterations to run is not reached. "
@@ -1230,10 +1219,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12301219
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
12311220
break
12321221

1233-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1234-
self.should_terminate = True
1235-
raise _EngineTerminateException()
1236-
12371222
except _EngineTerminateSingleEpochException:
12381223
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
12391224
self.should_terminate_single_epoch = False

ignite/engine/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
214214
)
215215

216216

217-
class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc]
217+
class EventEnum(CallableEventWithFilter, Enum):
218218
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
219219
this class.
220220

ignite/handlers/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,7 @@ def __init__(
962962
self,
963963
dirname: Union[str, Path],
964964
filename_prefix: str = "",
965-
save_interval: Optional[Callable] = None,
965+
save_interval: Optional[int] = None,
966966
score_function: Optional[Callable] = None,
967967
score_name: Optional[str] = None,
968968
n_saved: Union[int, None] = 1,

ignite/handlers/lr_finder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ def _run(
106106
max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
107107
if max_iter < num_iter:
108108
max_iter = num_iter
109-
trainer.state.max_iters = num_iter
110109
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length) # type: ignore[operator]
111110

112111
if not trainer.has_event_handler(self._reached_num_iterations):

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,6 @@ ignore_missing_imports = True
7777

7878
[mypy-torchvision.*]
7979
ignore_missing_imports = True
80+
81+
[mypy-ignite.contrib.handlers.custom_events]
82+
ignore_errors = True

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,47 +1029,6 @@ def switch_dataloader():
10291029

10301030
trainer.run(data1, max_epochs=10)
10311031

1032-
def test_run_with_max_iters(self):
1033-
max_iters = 8
1034-
engine = Engine(lambda e, b: 1)
1035-
engine.run([0] * 20, max_iters=max_iters)
1036-
assert engine.state.iteration == max_iters
1037-
assert engine.state.max_iters == max_iters
1038-
1039-
def test_run_with_max_iters_greater_than_epoch_length(self):
1040-
max_iters = 73
1041-
engine = Engine(lambda e, b: 1)
1042-
engine.run([0] * 20, max_iters=max_iters)
1043-
assert engine.state.iteration == max_iters
1044-
1045-
def test_run_with_invalid_max_iters_and_max_epoch(self):
1046-
max_iters = 12
1047-
max_epochs = 2
1048-
engine = Engine(lambda e, b: 1)
1049-
with pytest.raises(
1050-
ValueError,
1051-
match=r"Arguments max_iters and max_epochs are mutually exclusive."
1052-
"Please provide only max_epochs or max_iters.",
1053-
):
1054-
engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
1055-
1056-
def test_epoch_events_fired_max_iters(self):
1057-
max_iters = 32
1058-
engine = Engine(lambda e, b: 1)
1059-
1060-
@engine.on(Events.EPOCH_COMPLETED)
1061-
def fired_event(engine):
1062-
assert engine.state.iteration % engine.state.epoch_length == 0
1063-
1064-
engine.run([0] * 10, max_iters=max_iters)
1065-
1066-
def test_is_done_with_max_iters(self):
1067-
state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1068-
assert not Engine._is_done(state)
1069-
1070-
state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1071-
assert Engine._is_done(state)
1072-
10731032
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
10741033
def test_batch_is_released_before_new_one_is_loaded_on_cuda(self):
10751034
torch.cuda.empty_cache()

tests/ignite/handlers/test_lr_finder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
348348
trainer_with_finder.run(dataloader)
349349
assert_output_sizes(lr_finder, dummy_engine)
350350
assert dummy_engine.state.iteration != len(dataloader)
351-
assert dummy_engine.state.iteration == 150
351+
assert dummy_engine.state.iteration == 150 + 1
352352

353353

354354
def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):

0 commit comments

Comments (0)